In [16]:
import numpy as np
import  torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch import nn
import torch.optim as optim
import tensorflow as tf

In [17]:
class Block(nn.Module):
  expansion = 1
  def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
    super(Block, self).__init__()

    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
    self.batch_norm1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
    self.batch_norm2 = nn.BatchNorm2d(out_channels)

    self.i_downsample = i_downsample
    self.stride = stride
    self.relu = nn.ReLU()

  def forward(self, x):
    identity = x.clone()
    x = self.conv1(x)
    x = self.batch_norm2(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.batch_norm2(x)
    if self.i_downsample is not None:
        identity = self.i_downsample(identity)
    print(x.shape)
    print(identity.shape)
    x += identity
    x = self.relu(x)
    return x


class ResNet_vision(nn.Module):
  def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
    super(ResNet_vision, self).__init__()
    self.in_channels = 64
    
    self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.batch_norm1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU()
    self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2, padding=1)
    
    self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
    self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
    self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
    self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)
    
    self.avgpool = nn.AdaptiveAvgPool2d((1,1))
    self.fc = nn.Linear(512*ResBlock.expansion, num_classes)
      
  def forward(self, x):
    x = self.relu(self.batch_norm1(self.conv1(x)))
    x = self.max_pool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    
    x = self.avgpool(x)
    x = x.reshape(x.shape[0], -1)
    x = self.fc(x)
    
    return x
      
  def _make_layer(self, ResBlock, blocks, planes, stride=1):
    ii_downsample = None
    layers = []
    
    if stride != 1 or self.in_channels != planes*ResBlock.expansion:
        ii_downsample = nn.Sequential(
            nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride),
            nn.BatchNorm2d(planes*ResBlock.expansion)
        )
        
    layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
    self.in_channels = planes*ResBlock.expansion
    
    for i in range(blocks-1):
        layers.append(ResBlock(self.in_channels, planes))
        
    return nn.Sequential(*layers)


In [18]:
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
# Training the model
def train2d(model, train_loader, loss_func, optimizer):
      batch_loss = 0.0
      correct = 0
      total = 0
      for data in train_loader:
          inputs, labels = data
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = loss_func(outputs, labels)
          loss.backward()
          optimizer.step()

          batch_loss += loss.item()
          predicted_val, predicted_indices = torch.max(outputs.data, 1)
          total += labels.size(0)
          correct += (predicted_indices == labels).sum().item()

      epoch_loss = batch_loss / total
      epoch_acc = correct / total
      return epoch_loss, epoch_acc

def validate2d(model, loader, loss_func):
    model.eval()
    batch_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            batch_loss += loss.item() * inputs.size(0)
            predicted_val, predicted_indices = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted_indices == labels).sum().item()

    epoch_loss = batch_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [20]:
# Initialize DataLoader for the test data
test_loader = DataLoader(test_data, batch_size=128, shuffle=False)


num_epochs = 10
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

num_blocks = [3, 4, 6, 3]  
image_channel = 3 
num_classes = 10


model_2d = ResNet_vision(Block, num_blocks,num_classes=num_classes)

loss_func2d = nn.CrossEntropyLoss()
optimizer2d = optim.Adam(model_2d.parameters(), lr=0.0003)

# Training and validation loops
for epoch in range(num_epochs):
    train_loader = DataLoader(training_data, batch_size=128, shuffle=True)
    train_loss, train_acc = train2d(model_2d, train_loader, loss_func2d, optimizer2d)
    val_loss, val_acc = validate2d(model_2d, test_loader, loss_func2d)

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')


torch.Size([128, 64, 8, 8])
torch.Size([128, 64, 8, 8])
torch.Size([128, 64, 8, 8])
torch.Size([128, 64, 8, 8])
torch.Size([128, 64, 8, 8])
torch.Size([128, 64, 8, 8])
torch.Size([128, 128, 2, 2])
torch.Size([128, 128, 4, 4])


RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 3

In [None]:
epochs = range(1, num_epochs + 1)
plt.plot(epochs, train_accuracies, 'b', label='Train Accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation Accuracy')
plt.title('Train and Validation Accuracies')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()