In [1]:
import numpy as np 
import cv2
from PIL import Image
import matplotlib.pyplot as plt 
import os
import torch
import torch.nn as nn
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import ToTensor
import torchvision.transforms as T 
from torchvision.transforms import Compose
from torchvision.transforms import RandomHorizontalFlip, RandomCrop,Normalize
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam
device='cuda:0' if torch.cuda.is_available() else 'cpu'

class BasicBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=3):
        super(BasicBlock,self).__init__()
        
        self.c1=nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,padding=1)
        self.c2=nn.Conv2d(out_channels,out_channels,kernel_size=kernel_size,padding=1)
        
        self.downsample=nn.Conv2d(in_channels,out_channels,kernel_size=1)
        self.bn1=nn.BatchNorm2d(num_features=out_channels)
        self.bn2=nn.BatchNorm2d(num_features=out_channels)
        self.relu=nn.ReLU()
        
    def forward(self,x):
        x_=x
        
        x=self.c1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.c2(x)
        x=self.bn2(x)
        x_=self.downsample(x_)
        x+=x_
        x=self.relu(x)
        return x
        
    

In [2]:


class ResNet(nn.Module):
    def __init__(self,num_classes=10):
        super(ResNet,self).__init__()
        
        self.b1=BasicBlock(in_channels=3,out_channels=64)
        self.b2=BasicBlock(in_channels=64,out_channels=128)
        self.b3=BasicBlock(in_channels=128,out_channels=256)
        
        self.pool=nn.AvgPool2d(kernel_size=2,stride=2)
        
        self.fc1=nn.Linear(in_features=4096,out_features=2048)
        self.fc2=nn.Linear(in_features=2048,out_features=512)
        self.fc3=nn.Linear(in_features=512,out_features=num_classes)
        
        self.relu=nn.ReLU()
        
    def forward(self,x):
        x=self.b1(x)
        x=self.pool(x)
        x=self.b2(x)
        x=self.pool(x)
        x=self.b3(x)
        x=self.pool(x)
        
        x=torch.flatten(x,start_dim=1)
        
        x=self.fc1(x)
        x=self.relu(x)
        x=self.fc2(x)
        x=self.relu(x)
        x=self.fc3(x)
        
        return x
        
        
        

In [3]:
from torchvision.transforms import Resize,RandomHorizontalFlip,RandomCrop
transforms=Compose([
    RandomCrop((32,32),padding=4),
    RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    Normalize(mean=(0.4914,0.4822,0.4465),std=(0.247,0.243,0.261)),
])

training_data=CIFAR10(
    root='../data/',
    train=True,
    download=True,
    transform=transforms
)
test_data=CIFAR10(
    root='../data/',
    train=False,
    download=True,
    transform=transforms
)

train_loader=DataLoader(training_data,batch_size=2048,shuffle=True)
test_loader=DataLoader(test_data,batch_size=2048,shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
model=ResNet(num_classes=10)
model.to(device)
lr = 1e-4
import tqdm
optim = Adam(model.parameters(), lr=lr)

for epoch in range(30):
    iterator = tqdm.notebook.tqdm(train_loader)
    for data, label in iterator:
        optim.zero_grad()
        preds = model(data.to(device))
        loss = nn.CrossEntropyLoss()(preds, label.to(device))
        loss.backward()
        optim.step()
        iterator.set_description(f"epoch: {epoch+1} loss: {loss.item()}")
torch.save(model.state_dict(), "../model/ResNet.pth")

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

In [5]:
model.load_state_dict(torch.load("../model/ResNet.pth",map_location=device))
num_corr=0
with torch.no_grad():
    iterator = tqdm.tqdm_notebook(test_loader)
    for data,label in iterator:
        output=model(data.to(device))
        preds=output.data.max(1)[1]
        corr=preds.eq(label.to(device).data).sum().item()
        num_corr+=corr

    print(f"Accuracy:{num_corr/len(test_data)}")

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  iterator = tqdm.tqdm_notebook(test_loader)


  0%|          | 0/5 [00:00<?, ?it/s]

Accuracy:0.8009


In [12]:
import torchsummary
torchsummary.summary(model,(3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
            Conv2d-6           [-1, 64, 32, 32]             256
              ReLU-7           [-1, 64, 32, 32]               0
        BasicBlock-8           [-1, 64, 32, 32]               0
         AvgPool2d-9           [-1, 64, 16, 16]               0
           Conv2d-10          [-1, 128, 16, 16]          73,856
      BatchNorm2d-11          [-1, 128, 16, 16]             256
             ReLU-12          [-1, 128, 16, 16]               0
           Conv2d-13          [-1, 128, 16, 16]         147,584
      BatchNorm2d-14          [-1, 128,