In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as f
from torchvision import datasets, transforms
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split

import matplotlib.pyplot as plt
import numpy as np


# Model

In [2]:
class ConvBlocks(nn.Module):
  def __init__(self,in_channels, intermediate_channels, identity_downsample=None,stride=1):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels, intermediate_channels, kernel_size=1, stride=1,padding=0,bias=False)
    self.bn1 = nn.BatchNorm2d(intermediate_channels)
    self.conv2 = nn.Conv2d(intermediate_channels, intermediate_channels, 3, stride=stride,padding = 1, bias=False)
    self.bn2 = nn.BatchNorm2d(intermediate_channels)
    self.conv3 = nn.Conv2d(intermediate_channels, intermediate_channels*4, 1, 1, 0, bias=False)
    self.bn3 = nn.BatchNorm2d(intermediate_channels*4)
    self.relu = nn.ReLU()
    self.identity_downsample = identity_downsample

  def forward(self, x):
    identity = x.clone()

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.conv3(x)
    x = self.bn3(x)

    if self.identity_downsample is not None:
      identity = self.identity_downsample(identity)
    x += identity
    x = self.relu(x)
    return x

class ResNet(nn.Module):
  def __init__(self, block, layers, img_channels, num_classes):
    super(ResNet,self).__init__()

    self.in_channels = 64

    self.conv1 = nn.Conv2d(img_channels, 64, 7, 2, 3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU()
    self.maxpool = nn.MaxPool2d(3, 2, 1)
    self.layer1 = self._make_layer(
            block, layers[0], intermediate_channels=64, stride=1
        )
    self.layer2 = self._make_layer(
            block, layers[1], intermediate_channels=128, stride=2
        )
    self.layer3 = self._make_layer(
            block, layers[2], intermediate_channels=256, stride=2
        )
    self.layer4 = self._make_layer(
            block, layers[3], intermediate_channels=512, stride=2
        )

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * 4, num_classes)

  def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)

        return x

  def _make_layer(self, block, num_residual_blocks, intermediate_channels, stride):
    identity_downsample = None
    layers = []

    if stride != 1 or self.in_channels != intermediate_channels*4:
      identity_downsample = nn.Sequential(
          nn.Conv2d(self.in_channels, intermediate_channels*4, 1, stride = stride, bias=False),
          nn.BatchNorm2d(intermediate_channels*4),
      )
    layers.append(block(self.in_channels, intermediate_channels, identity_downsample, stride))
    self.in_channels = intermediate_channels *4

    for i in range(num_residual_blocks - 1):
      layers.append(block(self.in_channels, intermediate_channels))
    return nn.Sequential(*layers)


In [4]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCH = 70
PATH = "resnet_checkpoint.pth"
BATCH_SIZE = 64
LR = 0.0001

In [7]:
def ResNet50(img_channel=3, num_classes=1000):
    return ResNet(ConvBlocks, [3, 4, 6, 3], img_channel, num_classes)


def ResNet101(img_channel=3, num_classes=1000):
    return ResNet(ConvBlocks, [3, 4, 23, 3], img_channel, num_classes)


def ResNet152(img_channel=3, num_classes=1000):
    return ResNet(ConvBlocks, [3, 8, 36, 3], img_channel, num_classes)


In [11]:
!unzip /content/drive/MyDrive/data/Mammals_Images.zip -d /content

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/mammals/red_panda/red_panda-0250.jpg  
  inflating: /content/mammals/red_panda/red_panda-0251.jpg  
  inflating: /content/mammals/red_panda/red_panda-0252.jpg  
  inflating: /content/mammals/red_panda/red_panda-0253.jpg  
  inflating: /content/mammals/red_panda/red_panda-0254.jpg  
  inflating: /content/mammals/red_panda/red_panda-0255.jpg  
  inflating: /content/mammals/red_panda/red_panda-0256.jpg  
  inflating: /content/mammals/red_panda/red_panda-0257.jpg  
  inflating: /content/mammals/red_panda/red_panda-0258.jpg  
  inflating: /content/mammals/red_panda/red_panda-0259.jpg  
  inflating: /content/mammals/red_panda/red_panda-0260.jpg  
  inflating: /content/mammals/red_panda/red_panda-0261.jpg  
  inflating: /content/mammals/red_panda/red_panda-0262.jpg  
  inflating: /content/mammals/red_panda/red_panda-0263.jpg  
  inflating: /content/mammals/red_panda/red_panda-0264.jpg  
  inflating: /conten

In [12]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(root='/content/mammals', transform = transform)


In [13]:
train_size = int(0.96*len(dataset))
test_size = int(len(dataset)-train_size)
train_size,test_size

(13200, 551)

In [14]:
train_dataset, test_dataset = random_split(dataset , [train_size, test_size ])

In [15]:
train_loader = DataLoader(train_dataset , batch_size= BATCH_SIZE,shuffle = True, num_workers = 2, pin_memory = True, drop_last = True)
test_loader = DataLoader(test_dataset , batch_size= 32,shuffle = True)

In [16]:
model = ResNet50(num_classes = 45).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr= LR)
criterion = nn.CrossEntropyLoss()

In [17]:
from torchsummary import summary
summary(model, (3,224,224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
            Conv2d-7           [-1, 64, 56, 56]          36,864
       BatchNorm2d-8           [-1, 64, 56, 56]             128
            Conv2d-9          [-1, 256, 56, 56]          16,384
      BatchNorm2d-10          [-1, 256, 56, 56]             512
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
             ReLU-13          [-1, 256, 56, 56]               0
       ConvBlocks-14          [-1, 256,

In [19]:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [20]:
best_accuracy = 0.0
for epoch in range(EPOCH):
  epoch_loss = 0.0
  correct_prediction = 0
  total_samples = 0
  for data,targets in tqdm(train_loader):
      data = data.to(DEVICE)
      targets = targets.to(DEVICE)

      optimizer.zero_grad()
      scores = model(data)
      loss = criterion(scores, targets)

      loss.backward()
      optimizer.step()

      epoch_loss += loss.item()
      _,pred = torch.max(scores, 1)
      correct_prediction += torch.sum(pred == targets).item()
      total_samples += targets.size(0)
  epoch_accuracy = correct_prediction / total_samples
  epoch_loss /= len(train_loader)

  print('Epoch: {} \tLoss: {:.4f} \tAcc: {:.4f}'.format(epoch + 1, epoch_loss, epoch_accuracy))
  if epoch_accuracy > best_accuracy:
        best_accuracy = epoch_accuracy
        torch.save({'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_accuracy': best_accuracy}, PATH)


100%|██████████| 206/206 [02:08<00:00,  1.60it/s]


Epoch: 1 	Loss: 0.7572 	Acc: 0.7753


 17%|█▋        | 35/206 [00:22<01:49,  1.57it/s]


KeyboardInterrupt: 