In [1]:
import torch
import torch.nn as nn
import tqdm
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import ToTensor, Compose, RandomHorizontalFlip, RandomCrop, Normalize
from torch.utils.data import DataLoader
from torch.optim import Adam

In [2]:
class BasicBlock(nn.Module):
  def __init__(self,in_channels,out_channels,kernel_size=3):
    super(BasicBlock,self).__init__()

    self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,padding=1)
    self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=kernel_size,padding=1)

    self.downsample=nn.Conv2d(in_channels,out_channels,kernel_size=1)

    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    self.relu=nn.ReLU()

  def forward(self,x):
    x_=x

    x=self.conv1(x)
    x=self.bn1(x)
    x=self.relu(x)

    x=self.conv2(x)
    x=self.bn2(x)

    x_ = self.downsample(x_)

    x +=x_
    x=self.relu(x)

    return x

In [3]:
class Resnet(nn.Module):
  def __init__(self,num_classes=10):
    super(Resnet,self).__init__()

    self.b1=BasicBlock(in_channels=3,out_channels=64)
    self.b2=BasicBlock(in_channels=64,out_channels=128)
    self.b3=BasicBlock(in_channels=128,out_channels=256)

    self.pool=nn.AvgPool2d(kernel_size=2,stride=2)

    self.fc1=nn.Linear(in_features=4096,out_features=2048)
    self.fc2=nn.Linear(in_features=2048,out_features=512)
    self.fc3=nn.Linear(in_features=512,out_features=num_classes)

    self.relu=nn.ReLU()

  def forward(self,x):
    x=self.b1(x)
    x=self.pool(x)

    x=self.b2(x)
    x=self.pool(x)

    x=self.b3(x)
    x=self.pool(x)

    x=torch.flatten(x,start_dim=1)

    x=self.fc1(x)
    x=self.relu(x)
    x=self.fc2(x)
    x=self.relu(x)
    x=self.fc3(x)

    return x

In [4]:
transforms=Compose([
    RandomHorizontalFlip(p=0.5),
    RandomCrop((32,32),padding=4),
    ToTensor(),
    Normalize(mean=(0.4914,0.4822,0.4465),std=(0.247,0.243,0.261))])

training_data=CIFAR10(root='./data',train=True,download=True,transform=transforms)
test_data=CIFAR10(root='./data',train=False,download=True,transform=transforms)

train_loader=DataLoader(training_data,batch_size=32,shuffle=True)
test_loader=DataLoader(test_data,batch_size=32,shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 42.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model=Resnet(num_classes=10)
model.to(device)

lr=1e-4
optim=Adam(model.parameters(),lr=lr)

for epoch in range(30):
  iter=tqdm.tqdm(train_loader)
  for data, label in iter:
    optim.zero_grad()

    preds=model(data.to(device))

    loss=nn.CrossEntropyLoss()(preds,label.to(device))
    loss.backward()
    optim.step()

    iter.set_description(f'Epoch {epoch+1}')
    iter.set_postfix({'loss':loss.item()})

torch.save(model.state_dict(),'./Resnet_CIFAR10.pth')

Epoch 1: 100%|██████████| 1563/1563 [01:09<00:00, 22.55it/s, loss=1.33]
Epoch 2: 100%|██████████| 1563/1563 [00:56<00:00, 27.58it/s, loss=1.17]
Epoch 3: 100%|██████████| 1563/1563 [00:59<00:00, 26.42it/s, loss=0.774]
Epoch 4: 100%|██████████| 1563/1563 [00:50<00:00, 30.72it/s, loss=0.677]
Epoch 5: 100%|██████████| 1563/1563 [00:53<00:00, 29.38it/s, loss=0.714]
Epoch 6: 100%|██████████| 1563/1563 [00:48<00:00, 32.13it/s, loss=0.638]
Epoch 7: 100%|██████████| 1563/1563 [00:49<00:00, 31.51it/s, loss=0.383]
Epoch 8: 100%|██████████| 1563/1563 [00:48<00:00, 32.50it/s, loss=0.785]
Epoch 9: 100%|██████████| 1563/1563 [00:48<00:00, 32.49it/s, loss=0.0968]
Epoch 10: 100%|██████████| 1563/1563 [00:48<00:00, 32.35it/s, loss=0.483]
Epoch 11: 100%|██████████| 1563/1563 [00:48<00:00, 32.45it/s, loss=0.129]
Epoch 12: 100%|██████████| 1563/1563 [00:48<00:00, 32.16it/s, loss=0.515]
Epoch 13: 100%|██████████| 1563/1563 [00:48<00:00, 32.21it/s, loss=0.59]
Epoch 14: 100%|██████████| 1563/1563 [00:47<00:00

In [8]:
model.load_state_dict(torch.load('./Resnet_CIFAR10.pth',map_location=device))

num_corr=0

with torch.no_grad():
  for data, label in iter:
    output=model(data.to(device))
    preds=output.data.max(1)[1]
    corr=preds.eq(label.to(device).data).sum().item()
    num_corr+=corr

print(f'Accuracy:{num_corr/len(test_data)}')

  model.load_state_dict(torch.load('./Resnet_CIFAR10.pth',map_location=device))


Accuracy:0.885
