In [1]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import torch.nn as nn
import torch.optim as optim
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler
from ray import tune
import tempfile
from ray import train
from pathlib import Path
import os
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torchvision
from tqdm import tqdm
%matplotlib inline

In [2]:
class Block(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers plus a skip connection.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both convolutions.
        identity: Optional module applied to the block input before it is
            added to the convolutional branch. When ``None`` the raw input is
            added, so it must already have the branch's output shape.
        stride: Stride of the first convolution (the second always uses 1).
        padding: Padding used by both convolutions.
    """

    def __init__(self, in_channels, out_channels, identity=None, stride=1, padding=1):
        super().__init__()

        # Sub-modules are created in a fixed order so that seeded weight
        # initialisation and state_dict layout stay stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU()

        self.identity = identity

    def forward(self, x):
        """Return ``relu(branch(x) + skip(x))``."""
        # Convolutional branch: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip path: the input itself, or its transformed version when a
        # module was supplied at construction time.
        skip = x if self.identity is None else self.identity(x)

        out += skip
        return self.relu(out)

In [18]:
class ResNet34(nn.Module):
    """ResNet-34-style image classifier built from ``Block`` residual units.

    Layout: a 7x7/stride-2 convolution with batch norm and ReLU, a
    3x3/stride-2 max-pool, four stages of 3, 4, 6 and 3 blocks with 64, 128,
    256 and 512 channels, an adaptive average pool to 1x1 and a linear layer.

    Only the first block of conv3_x, conv4_x and conv5_x changes the tensor
    shape (stride 2, more channels), so only those blocks receive a 1x1
    projection on the skip path. Every other block adds its input unchanged.

    NOTE: conv4_2 .. conv4_6 carry no ``identity`` sub-module; a state dict
    containing ``conv4_N.identity.*`` entries needs ``strict=False`` to load.

    Args:
        num_classes: Number of output features of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()

        # conv2_x: 3 blocks, 64 channels, shape-preserving.
        self.conv2_1 = self._make_layer(in_channels=64, out_channels=64, stride=1, padding=1)
        self.conv2_2 = self._make_layer(in_channels=64, out_channels=64, stride=1, padding=1)
        self.conv2_3 = self._make_layer(in_channels=64, out_channels=64, stride=1, padding=1)

        # conv3_x: 4 blocks, 128 channels; the first one downsamples.
        self.conv3_1 = self._make_layer(
            in_channels=64, out_channels=128, stride=2, padding=1,
            identity=self._make_identity(in_channels=64, out_channels=128, stride=2))
        self.conv3_2 = self._make_layer(in_channels=128, out_channels=128, stride=1, padding=1)
        self.conv3_3 = self._make_layer(in_channels=128, out_channels=128, stride=1, padding=1)
        self.conv3_4 = self._make_layer(in_channels=128, out_channels=128, stride=1, padding=1)

        # conv4_x: 6 blocks, 256 channels; the first one downsamples.
        # Blocks 2-6 keep the shape, so they use the plain identity skip just
        # like the shape-preserving blocks of the other stages.
        self.conv4_1 = self._make_layer(
            in_channels=128, out_channels=256, stride=2, padding=1,
            identity=self._make_identity(in_channels=128, out_channels=256, stride=2))
        self.conv4_2 = self._make_layer(in_channels=256, out_channels=256, stride=1, padding=1)
        self.conv4_3 = self._make_layer(in_channels=256, out_channels=256, stride=1, padding=1)
        self.conv4_4 = self._make_layer(in_channels=256, out_channels=256, stride=1, padding=1)
        self.conv4_5 = self._make_layer(in_channels=256, out_channels=256, stride=1, padding=1)
        self.conv4_6 = self._make_layer(in_channels=256, out_channels=256, stride=1, padding=1)

        # conv5_x: 3 blocks, 512 channels; the first one downsamples.
        self.conv5_1 = self._make_layer(
            in_channels=256, out_channels=512, stride=2, padding=1,
            identity=self._make_identity(in_channels=256, out_channels=512, stride=2))
        self.conv5_2 = self._make_layer(in_channels=512, out_channels=512, stride=1, padding=1)
        self.conv5_3 = self._make_layer(in_channels=512, out_channels=512, stride=1, padding=1)

        # Head: pooling to 1x1 leaves exactly 512 features per sample.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=512, out_features=num_classes)

    def _make_layer(self, in_channels, out_channels, stride, padding, identity=None):
        """Build one residual ``Block``; ``identity`` is its optional skip-path module."""
        return Block(in_channels=in_channels, out_channels=out_channels,
                     identity=identity, stride=stride, padding=padding)

    def _make_identity(self, in_channels, out_channels, stride=1):
        """Build a 1x1 conv + batch-norm projection for a skip path.

        Used where a block changes channel count and/or spatial size, so the
        input can be added to the block's convolutional output.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=1, stride=stride),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` scores."""
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # conv2_x
        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.conv2_3(x)

        # conv3_x
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        # conv4_x
        x = self.conv4_1(x)
        x = self.conv4_2(x)
        x = self.conv4_3(x)
        x = self.conv4_4(x)
        x = self.conv4_5(x)
        x = self.conv4_6(x)

        # conv5_x
        x = self.conv5_1(x)
        x = self.conv5_2(x)
        x = self.conv5_3(x)

        # Head: (N, 512, 1, 1) -> (N, 512) -> (N, num_classes).
        x = self.avgpool(x)
        x = self.fc(x.flatten(1))

        return x
    

model = ResNet34(num_classes=10)

# Shape sanity check on a random batch. It runs in eval mode and without
# autograd so the random data neither updates the BatchNorm running
# statistics of the model that is trained later nor builds a graph.
x = torch.randn(20, 3, 224, 224)
model.eval()
with torch.no_grad():
    output_shape = model(x).shape
model.train()  # back to the default training mode for the cells below
output_shape

torch.Size([20, 10])

In [19]:
# Training hyperparameters.
batch_size = 32  # mini-batch size; not consumed by any cell shown here yet
lr = 1e-4        # learning rate passed to the Adam optimizer
num_epoch = 10   # number of iterations of the epoch loop

In [20]:
# Loss and optimizer for `model`: cross-entropy on the network's raw outputs,
# Adam over all of its parameters with the learning rate `lr`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
for epoch in range(num_epoch):

    # Training phase.
    # train() makes the BatchNorm layers normalise with batch statistics and
    # update their running statistics.
    model.train()
    # TODO: iterate over the training batches: forward pass, criterion,
    # optimizer.zero_grad(), loss.backward(), optimizer.step().

    # Validation phase.
    # eval() makes the BatchNorm layers use their running statistics, and
    # no_grad() skips autograd bookkeeping because nothing is back-propagated.
    model.eval()
    with torch.no_grad():
        # TODO: iterate over the validation batches and accumulate metrics.
        pass