In [1]:
import torch
from torch import nn
import torchvision
from torchvision.transforms import ToTensor, Resize, Compose, Normalize
from torchvision.datasets import CIFAR10
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from torch.utils.data import DataLoader
from torchvision.models import wide_resnet50_2 as wrn50

In [2]:
?wrn50

[0;31mSignature:[0m [0mwrn50[0m[0;34m([0m[0mpretrained[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0mprogress[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.

Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
    progress (bool): If True, displays a progress bar of the download to stderr
[0;31mFile:[0m      ~/Projects/venv/lib/python3.8/site-packages/torchvision/models/resnet.py
[0;31mType:[0m      function


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
#model = create_model('vit_base_patch16_384', pretrained=True)
#model = create_model('tv_resnet101', pretrained=True)
model = wrn50(pretrained=True)

In [5]:
model.fc = nn.Linear(2048, 10, bias=True)

In [6]:
#model.reset_classifier(num_classes=10)

In [7]:
transforms = Compose([Resize(384), ToTensor(), Normalize(mean=(0.5,0.5, 0.5), std=(0.5, 0.5, 0.5))])
cifar10 = CIFAR10(root='./datasets/train/', download=True, train=True, transform=transforms)

Files already downloaded and verified


In [8]:
cifar10_test = CIFAR10(root='./datasets/test/', download=True, train=False, transform=transforms)

Files already downloaded and verified


In [9]:
train_dl = DataLoader(cifar10, batch_size=16, drop_last=True)
test_dl = DataLoader(cifar10_test, batch_size=16, drop_last=True)

In [10]:
model = model.to(device)

In [11]:
loss_fn = nn.CrossEntropyLoss()

In [12]:
optim = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=5e-4)
#sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.95)
best_acc = 0.0

In [15]:
for idx, (img, label) in enumerate(train_dl):
    model.train()
    if idx > 40000:
        break
    optim.zero_grad()
    
    img = img.to(device)
    label = label.to(device)
    pred = model(img)
    loss_val = loss_fn(pred, label)
    loss_val.backward()
    optim.step()
    torch.save(model.state_dict(), './cifar10_tv_resnet101.pth')
    if idx % 100 == 0:
        model.eval()
        correct = 0.0
        with torch.no_grad():
            for timg, tlabel in test_dl:
                timg = timg.to(device)
                tlabel = tlabel.to(device)
                tout = model(timg)
                loss_t = loss_fn(tout, tlabel)
                correct += (tlabel == torch.argmax(tout, dim=1)).sum().item()
            acc = correct/10000.0
            #sched.step()
            if acc >= best_acc:
                print('saving model')
                torch.save(model.state_dict(), './cifar10_tv_resnet101.pth')
                best_acc = acc
        print(f'Idx:{idx}, Train_loss:{loss_val.item()}, Test loss:{loss_t.item()}, test accuracy:{acc:.2f}')
    

Idx:0, Train_loss:0.8777157068252563, Test loss:1.2062662839889526, test accuracy:0.58
Idx:100, Train_loss:1.4707999229431152, Test loss:1.4834874868392944, test accuracy:0.55
Idx:200, Train_loss:1.6017284393310547, Test loss:0.8672785758972168, test accuracy:0.58
Idx:300, Train_loss:0.9733362793922424, Test loss:1.4643990993499756, test accuracy:0.57
Idx:400, Train_loss:1.1839113235473633, Test loss:1.1384563446044922, test accuracy:0.60
Idx:500, Train_loss:1.2378023862838745, Test loss:1.4023743867874146, test accuracy:0.55
Idx:600, Train_loss:1.4214180707931519, Test loss:1.5212477445602417, test accuracy:0.56
Idx:700, Train_loss:0.8655726909637451, Test loss:0.905354917049408, test accuracy:0.61
Idx:800, Train_loss:1.1174672842025757, Test loss:1.3106553554534912, test accuracy:0.61
Idx:900, Train_loss:1.0688982009887695, Test loss:1.202191948890686, test accuracy:0.61
Idx:1000, Train_loss:1.5668672323226929, Test loss:1.2677035331726074, test accuracy:0.57
Idx:1100, Train_loss:1.0