<a href="https://colab.research.google.com/github/UOS-COMP6252/public/blob/main/lecture5/conv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Convolutional Neural Networks

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim import SGD,Adam

### A Convolutional Network for CIFAR-10

In [None]:
# Seed the RNGs for (partial) reproducibility, then choose the compute device:
# the GPU when CUDA is available, otherwise the CPU.
seed = 9
torch.manual_seed(seed)
use_cuda = torch.cuda.is_available()
if use_cuda:
    # Seed every GPU and ask cuDNN for deterministic kernels.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
device = torch.device('cuda' if use_cuda else 'cpu')

In [None]:
# ToTensor turns each PIL image into a float tensor of shape (3, 32, 32)
# with pixel values scaled to [0, 1]; no further normalisation is applied.
transform = transforms.ToTensor()

# Both splits share the same root directory, download flag and transform.
cifar_kwargs = dict(root=".", download=True, transform=transform)
dataset_train = vision.datasets.CIFAR10(train=True, **cifar_kwargs)
dataset_test = vision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Shuffled mini-batches of 64 for training; large ordered batches for testing.
loader_train = DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=2)
loader_test = DataLoader(dataset_test, batch_size=512, shuffle=False)

In [None]:
class Net(nn.Module):
    """Small CNN for 3x32x32 inputs (CIFAR-10) producing 10 class logits.

    Two stages of conv3x3 -> ReLU -> conv3x3 -> ReLU -> 2x2 max-pool, then a
    128-unit hidden layer and a 10-way linear classifier. The 3x3
    convolutions use no padding, so each one shrinks height and width by 2.
    Shapes in the comments are (channels, height, width) per sample.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3x32x32 -> 32x30x30 -> 32x28x28 -> pool -> 32x14x14
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.relu = nn.ReLU()  # one parameter-free activation, reused everywhere
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        # Stage 2: 32x14x14 -> 64x12x12 -> 64x10x10 -> pool -> 64x5x5
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        # Head: 64*5*5 = 1600 flattened features -> 128 -> 10 logits
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=64 * 5 * 5, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def _stage(self, x, first, second, pool):
        """Apply first conv -> ReLU -> second conv -> ReLU -> pool."""
        x = self.relu(first(x))
        x = self.relu(second(x))
        return pool(x)

    def forward(self, x):
        x = self._stage(x, self.conv1, self.conv2, self.pool1)
        x = self._stage(x, self.conv3, self.conv4, self.pool2)
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)
    

In [None]:
def get_accuracy(dataloader, model, device):
    """Return the fraction of samples in `dataloader` that `model` classifies correctly.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        model: callable mapping a batch of inputs to per-class scores; the
            predicted class is the argmax over dimension 1.
        device: device each batch is moved to before calling `model`
            (must match the device the model lives on).

    Returns:
        Accuracy as a Python float in [0, 1]; 0.0 if the loader yields no samples.

    If `model` is an ``nn.Module`` it is evaluated in eval mode, and its
    previous train/eval mode is restored afterwards.
    """
    is_module = isinstance(model, nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    correct = 0
    # Count samples actually seen rather than reading `dataset.data`, which
    # only some datasets have and which ignores samplers / drop_last.
    total = 0
    try:
        # Inference only: don't build an autograd graph for every batch.
        with torch.no_grad():
            for imgs, labels in dataloader:
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = model(imgs)
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        if was_training:
            model.train()

    return correct / total if total else 0.0

In [None]:
# Instantiate the network on the selected device, with cross-entropy loss
# for the 10-way classification and Adam at its default hyper-parameters.
model = Net()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=model.parameters())

In [None]:
from tqdm import tqdm

epochs = 20
for epoch in range(epochs):
    # Back to training mode every epoch, since the evaluation below switches
    # to eval mode. (Net has no dropout/batch-norm, so this is a safeguard.)
    model.train()
    loop = tqdm(loader_train)
    loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
    # Exponentially smoothed batch loss shown in the progress bar (not a true
    # epoch mean). Seeded with the first batch's loss so early readings are
    # not biased towards 0.
    epoch_loss = None
    for imgs, labels in loop:
        optimizer.zero_grad()
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        if epoch_loss is None:
            epoch_loss = batch_loss
        else:
            epoch_loss = 0.9 * epoch_loss + 0.1 * batch_loss
        loop.set_postfix(loss=epoch_loss)
    # Evaluate without tracking gradients, and report the accuracies instead
    # of computing them only to throw them away.
    model.eval()
    with torch.no_grad():
        t_acc = get_accuracy(loader_train, model, device)
        v_acc = get_accuracy(loader_test, model, device)
    print(f"Epoch [{epoch+1}/{epochs}] train acc: {t_acc:.4f}, test acc: {v_acc:.4f}")

In [None]:
try:
   from torchmetrics import ConfusionMatrix
except: 
    !pip install torchmetrics
    from torchmetrics import ConfusionMatrix

conmat=ConfusionMatrix(task='multiclass',num_classes=10)
conmat=conmat.to(device)

In [None]:
# Run the trained model over the whole test set, accumulating the accuracy
# counts and the confusion matrix. Inference only: eval mode and no_grad, so
# no autograd graph is built.
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for imgs, labels in loader_test:
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        # index of the highest score = predicted class
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        conmat.update(predicted, labels)
print(f"test accuracy: {correct / total:.4f} ({correct}/{total})")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sb

# Counts accumulated by conmat.update(predicted, labels) over the test set.
# torchmetrics' ConfusionMatrix puts the true class on the rows and the
# predicted class on the columns; label the axes so the plot is unambiguous.
x = conmat.compute().cpu().numpy()
plt.figure(figsize=(10, 7))
ax = sb.heatmap(
    x,
    xticklabels=dataset_train.classes,
    yticklabels=dataset_train.classes,
    annot=True,
    fmt=".0f",
)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")

In [None]:
from torchvision.models.feature_extraction import get_graph_node_names,create_feature_extractor

In [None]:
# get_graph_node_names returns (train-mode node names, eval-mode node names);
# keep the train-mode list.
train_nodes, eval_nodes = get_graph_node_names(model)
names = train_nodes

In [None]:
# Display the traced node names; the keys of return_nodes below are chosen from these.
names

In [None]:
# Map the four convolution nodes to the output keys 'layer1'..'layer4'.
return_nodes = {f'conv{i}': f'layer{i}' for i in range(1, 5)}

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Wrap the model so calling f(imgs) returns a dict of the selected
# intermediate outputs, keyed 'layer1'..'layer4'.
f = create_feature_extractor(model, return_nodes=return_nodes)

![Image Description](image_path)


In [None]:
def plot_activations(activation, n_row=8, cmap='bwr'):
    """Show every channel of one activation map as a grid of small images.

    Parameters
    ----------
    activation : numpy.ndarray
        Channels-last array of shape (H, W, C), e.g. a conv output moved to
        NumPy and transposed with ``.transpose(1, 2, 0)``. It is not modified.
    n_row : int, optional
        Number of grid rows (default 8). The column count is chosen so that
        all C channels fit; unused trailing cells are left blank.
    cmap : str, optional
        Matplotlib colormap name (default 'bwr'). Other maps worth trying:
        'viridis', 'coolwarm', 'Greys', 'inferno', 'copper'.
    """
    n_channels = activation.shape[-1]
    # Ceil division so no channel is dropped when C is not a multiple of
    # n_row; at least one column so plt.subplots gets a valid grid.
    n_column = max(1, (n_channels + n_row - 1) // n_row)

    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(n_row, n_column, squeeze=False)
    for i in range(n_row):
        for j in range(n_column):
            k = i * n_column + j
            if k < n_channels:
                # Scale to [0, 255] and clip for display. Build a new array:
                # the slice is a view, so an in-place `*=` would rescale the
                # caller's activation array.
                channel_image = np.clip(activation[:, :, k] * 255, 0, 255).astype('uint8')
                ax[i, j].imshow(channel_image, cmap=cmap)
            ax[i, j].axis('off')
            ax[i, j].set_xticklabels([])
            ax[i, j].set_yticklabels([])
    plt.subplots_adjust(wspace=0.5, hspace=0.5)
    fig.set_size_inches(n_column, n_row)
    plt.show()

In [None]:
# One persistent iterator over the (shuffled) training loader; the cells below
# draw batches from it with next().
itr=iter(loader_train)

In [None]:
# Draw training batches until one starts with a class-0 sample (the first
# entry of dataset_train.classes), then extract its intermediate activations.
# Always draw a fresh batch first rather than testing whatever `labels` an
# earlier cell left behind, which could skip the search entirely.
imgs, labels = next(itr)
while labels[0].item() != 0:
    imgs, labels = next(itr)
imgs = imgs.to(device)
with torch.no_grad():
    output = f(imgs)
# conv3 ('layer3') activations of the first image, reordered from (C, H, W)
# to channels-last (H, W, C) = (12, 12, 64) for plot_activations.
a = output['layer3'][0].cpu().numpy().transpose(1, 2, 0)

In [None]:
# Gather 32 maps: for each training batch whose first sample is class 0,
# take that sample's conv3 ('layer3') activations and sum over channels,
# giving one H x W map per sample.
n = 0
seq = []
while n != 32:
    imgs, labels = next(itr)
    if labels[0].item() != 0:
        continue
    n += 1
    imgs = imgs.to(device)
    with torch.no_grad():
        output = f(imgs)
        # (C, H, W) -> (H, W, C), then collapse the channel axis
        a = output['layer3'][0].cpu().numpy().transpose(1, 2, 0)
        b = a.sum(axis=2)
        seq.append(b)

In [None]:

# Display each collected channel-summed activation map as its own figure.
for b in seq:
    plt.imshow(b)
    plt.show()

In [None]:
# Plot every channel of `a`, the most recently extracted layer3 activation map.
plot_activations(a)

In [None]:
# First image of the current batch, as a channels-last (H, W, C) NumPy array for imshow.
img = np.transpose(imgs[0].cpu().numpy(), (1, 2, 0))
def rgb2gray(rgb):
    """Convert a channels-last RGB(A) array to grayscale.

    Uses approximately the ITU-R BT.601 luma weights on the first three
    channels of the last axis (any alpha channel is ignored) and returns an
    array with that last axis removed.
    """
    luma_weights = np.array([0.2989, 0.5870, 0.1140])
    return rgb[..., :3] @ luma_weights
# Show `img` (first image of the current batch) at thumbnail size.
plt.figure(figsize=(1.5,1.5))
# Grayscale alternative: plt.imshow(rgb2gray(img), cmap='gray_r')
plt.imshow(img)   

In [None]:
from matplotlib import colormaps
# Names of every registered Matplotlib colormap (candidates for a `cmap=` argument).
[*colormaps]