In [5]:
import torch
import torch.nn as nn

class IdentityBlock(nn.Module):
    # just add the input as is at the end of forward pass
    def __init__(self, input_channels,output_channels,stride=1,downsample=None):
        super(IdentityBlock,self).__init__()
        padding =1
        self.block = torch.nn.Sequential(
            nn.Conv2d(input_channels,out_channels=output_channels,kernel_size=3,stride=stride,padding=padding),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=output_channels,out_channels=output_channels,kernel_size=3,stride=1,padding=padding),
            nn.BatchNorm2d(output_channels),
        )
        self.downsample = downsample
        self.relu = nn.ReLU()

    def forward(self,x):
        residue = x
        out = self.block(x)
        if self.downsample :
            residue = self.downsample(x)
        # print("out shape :"+ str(out.shape))
        # print("residue shape :"+str(residue.shape))
        out += residue
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1,
    # padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
    def __init__(self,layers,block,num_of_classes):
        super(ResNet,self).__init__()
        self.inplanes = 64
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3,self.inplanes,7,2,3),
            nn.BatchNorm2d(self.inplanes),
            nn.ReLU()
        )
        self.initial_pool = nn.MaxPool2d(3,2,1)
        self.resnet_layer1 = self.layer_with_blocks(block,layers[0],64,stride=1)
        self.resnet_layer2 = self.layer_with_blocks(block,layers[1],128,stride=2)
        self.resnet_layer3 = self.layer_with_blocks(block,layers[2],256,stride=2)
        self.resnet_layer4 = self.layer_with_blocks(block,layers[3],512,stride=2)
        self.avg_pool = nn.AvgPool2d(7,stride=1)
        self.final = nn.Linear(512,num_of_classes)

    def layer_with_blocks(self,block,number_of_blocks,channels,stride=1):
        downsample = None
        if self.inplanes != channels or  stride != 1:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,channels,kernel_size=1,stride=stride),
                nn.BatchNorm2d(channels)
            )
        layers = []
        layers.append(block(self.inplanes,channels,stride,downsample))
        self.inplanes = channels
        for i in range(1,number_of_blocks):
            layers.append(block(self.inplanes,channels))

        return nn.Sequential(*layers)

    def forward(self,x):
        out = self.initial_conv(x)
        out = self.initial_pool(out)
        out = self.resnet_layer1(out)
        # print("out.shape after layer 1:"+str(out.shape))
        out = self.resnet_layer2(out)
        # print("out.shape after layer 2:"+str(out.shape))
        out = self.resnet_layer3(out)
        # print("out.shape after layer 3:"+str(out.shape))
        out = self.resnet_layer4(out)
        # print("out.shape after layer 4:"+str(out.shape))
        out = self.avg_pool(out)
        # print("out.shape avg pool :"+str(out.shape))
        out = out.view(out.size(0),-1)
        out = self.final(out)
        return out


In [None]:
# import torch
import torchvision
# import torch.nn as nn
import gc
# import resnet

# prepare data
batch_size = 128
num_of_classes = 10

# train_dataset = torchvision.datasets.CIFAR10(root="./data",train=True,download=True)
# test_dataset = torchvision.datasets.CIFAR10(root="./data",train=False,download=True)
transform=torchvision.transforms.Compose([torchvision.transforms.Resize((224,224)),torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.CIFAR10(root="./data",train=True,download=True,transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root="./data",train=False,download=True,transform=transform)

train_data_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_data_loader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=False)
train_data = iter(train_data_loader)
test_data = iter(test_data_loader)
# make model
model = ResNet([2,2,3,2],IdentityBlock,num_of_classes=num_of_classes)

# make loss and optimizer
loss = nn.CrossEntropyLoss()

epochs = 20
learning_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(),learning_rate,weight_decay=0.001,momentum=0.9)

# train loop
for epoch in range(epochs):
    l = None
    for i, (batch,batch_labels) in enumerate(train_data):
        predictions = model(batch)
        l = loss(predictions,batch_labels)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        del predictions
        gc.collect()

    # if epoch%10 == 0:
        # print loss
        print(f'for epoch : {epoch} and iteration:{i} the loss : {l.item()}')


with torch.no_grad():
    correct = 0
    total = 0
    for i, (batch,batch_labels) in enumerate(test_data):
        predictions = model(batch)
        total += batch_labels.shape[0]
        _, predicted = torch.max(predictions.data, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

    print(" accuracy : "+ str(correct/total))

Files already downloaded and verified
Files already downloaded and verified
for epoch : 0 and iteration:0 the loss : 2.410749912261963
for epoch : 0 and iteration:1 the loss : 2.311783790588379
for epoch : 0 and iteration:2 the loss : 2.2927086353302
for epoch : 0 and iteration:3 the loss : 2.339174747467041
for epoch : 0 and iteration:4 the loss : 2.330965757369995
for epoch : 0 and iteration:5 the loss : 2.2805933952331543
for epoch : 0 and iteration:6 the loss : 2.3229475021362305
for epoch : 0 and iteration:7 the loss : 2.2620816230773926
for epoch : 0 and iteration:8 the loss : 2.2387170791625977
for epoch : 0 and iteration:9 the loss : 2.1961588859558105
for epoch : 0 and iteration:10 the loss : 2.2143776416778564
for epoch : 0 and iteration:11 the loss : 2.214686870574951
for epoch : 0 and iteration:12 the loss : 2.191342353820801
for epoch : 0 and iteration:13 the loss : 2.1476385593414307
for epoch : 0 and iteration:14 the loss : 2.20194673538208
for epoch : 0 and iteration:15