<a href="https://colab.research.google.com/github/Wenjie0o0/ClassAI/blob/main/Resnet50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

In [39]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [40]:
num_epochs = 2
batch_size = 64
learning_rate = 1e-4 #0.0001

In [41]:
transform = transforms.Compose(
    [ transforms.Resize([224,224]),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [68]:
class Resnet(nn.Module):
  def __init__(self,in_channels=64,num_classes=10):
    super().__init__()
    self.in_channels=in_channels
    
    self.conv1=conv_block(in_channels=3,out_channels=64,kernel_size=7,stride=2,padding=3)
    self.relu=nn.ReLU()
    self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)

    self.layer1=self.make_layer(64,3,stride=1)
    self.layer2=self.make_layer(128,4,stride=2)
    self.layer3=self.make_layer(256,6,stride=2)
    self.layer4=self.make_layer(512,3,stride=2)

    self.avgpool=nn.AvgPool2d(7)
    self.fc=nn.Linear(2048,num_classes)

  def forward(self,x):
    x=self.conv1(x)
    x=self.relu(x)
    x=self.maxpool(x)

    x=self.layer1(x)
    x=self.layer2(x)
    x=self.layer3(x)
    x=self.layer4(x)

    x=self.avgpool(x)
    x=x.reshape(x.shape[0],-1)
    out=self.fc(x)
    return out

  def make_layer(self,in_channels,block_num,stride=1):
    block_list=[]
    downsample=None
    if(stride!=1 or self.in_channels!=in_channels*4):
      downsample=nn.Sequential(
          conv_block(self.in_channels,in_channels*4,stride=stride,kernel_size=1)
      )
      block1=Bottleneck(self.in_channels,in_channels,stride=stride,downsample=downsample)
      block_list.append(block1)
      self.in_channels=in_channels*4

      for i in range(1,block_num):
        block_list.append(Bottleneck(self.in_channels,in_channels,stride=1))
      return nn.Sequential(*block_list)


class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        
    def forward(self, x):
        return self.batchnorm(self.conv(x))

class Bottleneck(nn.Module):
  def __init__(self,in_channels,out_channels,stride,downsample=None):
    super().__init__()
    self.conv1=conv_block(in_channels,out_channels,kernel_size=1,stride=stride)
    self.conv2=conv_block(out_channels,out_channels,kernel_size=3,stride=1)
    self.conv3=conv_block(out_channels,out_channels*4,kernel_size=1,stride=1)
    self.relu=nn.ReLU()

    self.downsample=downsample
    self.stride=stride

  def forward(self,x):
    residual=x
    out=self.conv1(x)
    out=self.relu(out)
    out=self.conv2(out)
    out=self.relu(out)
    out=self.conv3(out)
    out=self.relu(out)

    if self.downsample is not None:
      residual=self.downsample(x)
    out+=residual
    out=self.relu(out)

    return out

In [69]:
resnet = Resnet().to(device)

In [70]:
print(resnet)

Resnet(
  (conv1): conv_block(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): conv_block(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): conv_block(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv3): conv_block(
        (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU()
      (downs

In [71]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=learning_rate)

In [72]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)
count_parameters(resnet)

23555082

In [73]:
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
  for i, (images, labels) in enumerate(train_loader):
    # origin shape: [64, 3, 224, 224]
    images = images.to(device)
    labels = labels.to(device)
    # Forward pass
    
    outputs = resnet(images)
    loss = criterion(outputs, labels)
    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i+1) % 100 == 0:
        print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

RuntimeError: ignored

In [None]:
with torch.no_grad():
  n_correct = 0
  n_samples = 0
  n_class_correct = [0 for i in range(10)]
  n_class_samples = [0 for i in range(10)]
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = resnet(images)
    # max returns (value ,index)
    _, predicted = torch.max(outputs, 1)
    n_samples += labels.size(0)
    n_correct += (predicted == labels).sum().item()

  acc = 100.0 * n_correct / n_samples
  print(f'Accuracy of the network: {acc} %')