## ResNet Implementation in PyTorch

Credit goes to Jay Patwardhan for providing an example implementation of ResNet-50, ResNet-101, and ResNet-152 in PyTorch. This work is inspired by the paper "Deep Residual Learning for Image Recognition" by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import torchvision.models as models
r3d_18 = models.video.r3d_18(pretrained=True)

In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from ResNet import Bottleneck, ResNet, ResNet50
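
The local `ResNet` module imported above is not shown in this notebook. As a rough illustration only, a bottleneck residual block in the spirit of the paper could look like the sketch below. The class name `BottleneckSketch`, its constructor arguments, and the projection shortcut are assumptions made for this example and may differ from the imported `Bottleneck` class.

```python
import torch
import torch.nn as nn


class BottleneckSketch(nn.Module):
    """Hypothetical bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut."""

    expansion = 4  # the block outputs planes * 4 channels, as in the paper

    def __init__(self, in_channels, planes, stride=1):
        super().__init__()
        out_channels = planes * self.expansion
        self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # Project the shortcut with a 1x1 convolution when the shape changes
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x if self.shortcut is None else self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Residual connection: add the (possibly projected) input back in
        return self.relu(out + identity)


block = BottleneckSketch(in_channels=64, planes=64)
y = block(torch.randn(2, 64, 32, 32))
print(y.shape)
```

The block widens 64 input channels to 256 output channels while keeping the spatial size, so the shortcut needs the 1x1 projection here.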

In [6]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
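
To make the effect of `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` concrete, the sketch below reproduces the same per-channel arithmetic, `(x - mean) / std`, on a random tensor standing in for a `ToTensor()` output. The `denormalize` helper is a hypothetical name introduced here; it can be handy when plotting images after normalization.

```python
import torch

# Per-channel statistics, shaped so they broadcast over (C, H, W)
mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)

# Stand-in for a CIFAR-10 image after ToTensor(): values in [0, 1)
img = torch.rand(3, 32, 32)

# Same arithmetic as transforms.Normalize: maps [0, 1] to [-1, 1]
normalized = (img - mean) / std


def denormalize(t):
    """Hypothetical helper: undo the normalization, e.g. before plotting."""
    return t * std + mean


print(normalized.min().item(), normalized.max().item())
```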

In [7]:
train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

trainloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=2)

test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

testloader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [9]:
net = ResNet50(10).to('cuda')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
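
As a small illustration of how `ReduceLROnPlateau` reacts to a stalled loss, the sketch below uses the same optimizer and scheduler settings on a tiny stand-in model (an `nn.Linear`, which is an assumption for this example, not the ResNet) and feeds the scheduler a constant loss value. With `patience=5` and the default threshold, the learning rate should hold at 0.1 until more than five consecutive non-improving steps have been seen, then be multiplied by `factor`.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # tiny stand-in model, only here to supply parameters
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

lrs = []
for epoch in range(8):
    # Pretend the mean epoch loss is stuck at 1.0 and never improves
    scheduler.step(1.0)
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```

The first call sets the best value seen so far, so it does not count as a bad epoch; the reduction happens on the seventh call.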

In [None]:
EPOCHS = 200
for epoch in range(EPOCHS):
    net.train()  # BatchNorm layers use batch statistics during training
    losses = []
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report the mean loss over each window of exactly 100 minibatches
        if (i + 1) % 100 == 0:
            print(f'Loss [{epoch+1}, {i+1}](epoch, minibatch): ', running_loss / 100)
            running_loss = 0.0

    # Lower the learning rate when the mean epoch loss stops improving
    avg_loss = sum(losses) / len(losses)
    scheduler.step(avg_loss)

print('Training Done')

Loss [1, 100](epoch, minibatch):  8.286106350421905
Loss [1, 200](epoch, minibatch):  3.2815474700927734
Loss [1, 300](epoch, minibatch):  2.5430442595481875
Loss [2, 100](epoch, minibatch):  2.323899998664856
Loss [2, 200](epoch, minibatch):  2.2045424485206606
Loss [2, 300](epoch, minibatch):  2.104859755039215
Loss [3, 100](epoch, minibatch):  2.020149667263031
Loss [3, 200](epoch, minibatch):  1.9646476566791535
Loss [3, 300](epoch, minibatch):  1.921252702474594
Loss [4, 100](epoch, minibatch):  1.8857625865936278
Loss [4, 200](epoch, minibatch):  1.829829285144806
Loss [4, 300](epoch, minibatch):  1.8138708925247193
Loss [5, 100](epoch, minibatch):  1.7948732697963714
Loss [5, 200](epoch, minibatch):  1.7556228053569793
Loss [5, 300](epoch, minibatch):  1.7355257058143616
Loss [6, 100](epoch, minibatch):  1.7004099762439728
Loss [6, 200](epoch, minibatch):  1.655918046236038
Loss [6, 300](epoch, minibatch):  1.6366480994224548
Loss [7, 100](epoch, minibatch):  1.6395927166938782


In [None]:
correct = 0
total = 0

net.eval()  # use BatchNorm running statistics instead of per-batch statistics
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to('cuda'), labels.to('cuda')
        outputs = net(images)

        # The index of the largest logit is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy on 10,000 test images: ', 100*(correct/total), '%')
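
The `classes` list defined earlier is not used by the overall accuracy computation. One possible use, sketched below, is a per-class breakdown. The helper `per_class_accuracy` and the small hand-written `predicted` and `labels` tensors are hypothetical and only illustrate the bookkeeping; in the notebook they would come from the test loop.

```python
import torch

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def per_class_accuracy(predicted, labels, num_classes):
    """Hypothetical helper: count correct and total predictions per class."""
    correct = torch.zeros(num_classes, dtype=torch.long)
    total = torch.zeros(num_classes, dtype=torch.long)
    for c in range(num_classes):
        mask = labels == c  # samples whose true class is c
        total[c] = mask.sum()
        correct[c] = (predicted[mask] == c).sum()
    return correct, total


# Toy data standing in for the model's predictions on a batch
labels = torch.tensor([0, 0, 1, 1, 2, 9])
predicted = torch.tensor([0, 1, 1, 1, 2, 3])

correct, total = per_class_accuracy(predicted, labels, len(classes))
for name, c, t in zip(classes, correct.tolist(), total.tolist()):
    if t > 0:  # skip classes that do not appear in this batch
        print(f'{name}: {c}/{t} correct')
```

In a full evaluation, the counts would be accumulated across all batches of `testloader` before dividing.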