In [1]:
import torch 
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
transform_train= transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
])
transform_test = transforms.Compose([
    transforms.ToTensor()
])

In [6]:
import torchvision 
trainset = torchvision.datasets.CIFAR100(root='/.data', train=True, download=True, transform=transform_train)
trainloader  = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

100%|██████████| 169M/169M [00:03<00:00, 43.2MB/s] 


In [7]:
testset = torchvision.datasets.CIFAR100(root='/.data', train=False, download=True, transform=transform_test)
testloader  = DataLoader(trainset, batch_size=64, shuffle=False, num_workers=2)

In [8]:
len(trainset), len(testset)

(50000, 10000)

In [9]:
class_names  = trainset.classes
class_names[:10]

['apple',
 'aquarium_fish',
 'baby',
 'bear',
 'beaver',
 'bed',
 'bee',
 'beetle',
 'bicycle',
 'bottle']

In [10]:
model = torchvision.models.resnet18(weights=False).to(device)
model.fc  = nn.Linear(model.fc.in_features, 100)



In [11]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 2.8876e-02, -3.0156e-02,  2.9927e-02,  ...,  4.9592e-02,
                         -6.9721e-02, -1.9763e-02],
                        [-4.0642e-02,  5.0081e-02,  1.7049e-02,  ...,  2.5307e-02,
                         -2.6445e-02, -3.7567e-02],
                        [-2.3175e-02, -3.1586e-02, -2.7729e-02,  ..., -6.4584e-02,
                          7.3988e-04, -6.1232e-03],
                        ...,
                        [-4.0148e-02,  5.3276e-03,  7.4975e-03,  ...,  2.9128e-02,
                         -1.3143e-02, -7.7447e-03],
                        [ 2.0650e-02, -4.1864e-02,  2.9389e-02,  ..., -1.2972e-02,
                         -1.0708e-03,  6.5026e-03],
                        [-3.3814e-02,  4.7434e-02, -2.2855e-02,  ...,  2.1980e-02,
                         -8.7269e-03,  8.4399e-03]],
              
                       [[ 1.3807e-02,  2.4859e-02, -9.1523e-03,  ..., -1.0156e-03,
                          1.1647

In [14]:
import torch.optim as optim
loss_fn  = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2, eta_min=0.01)

In [None]:
def accuracy_fn(y_true, y_pred):
    correct = torch.ep(y_pred, y_true).sum().item()
    acc = correct/len(y_pred)
    return acc