In [1]:
# python libraries
import os
import sys
import dataclasses
from datetime import datetime
from pathlib import Path
from operator import methodcaller
from collections import OrderedDict
from dataclasses import dataclass
from typing import (
    List,
    Tuple,
    Dict,
    Any,
    Mapping,
    Callable
)
from enum import Enum
# Make the parent of the working directory importable (appended only once).
if str(Path.cwd().parent) not in sys.path:
    sys.path.append(str(Path.cwd().parent))

# numpy
import numpy as np

# torch
import torch
from torch import (
    nn,
    Tensor
    )
from torch.nn import functional as F


try:
    from torchmetrics import Accuracy
except:
    print(f"[INFO] Installing the torchmetrics")
    %pip install torchmetrics
    from torchmetrics import Accuracy

try:
    import torchinfo
except:
    print(f"[INFO] Installing the torchinfo")
    %pip install torchinfo
    import torchinfo

# helper function
try:
    import my_helper as helper
except:
    print("[INFO] Downloading the helper function from github")
    import requests
    response = requests.get("https://raw.githubusercontent.com/Lashi0812/PyTorch2/master/my_helper.py")
    with open("my_helper.py" ,"wb") as f:
        f.write(response.content)
    import my_helper as helper


## Connect Persistence memory
try :
    from google.colab import drive

    # Paths
    DRIVE_PATH = Path("/content/drive")
    MODEL_SAVE_PATH = Path("/content/drive/Othercomputers/My PC/drive/models")

    # mount drive
    drive.mount(str(DRIVE_PATH))
except:
    MODEL_SAVE_PATH = Path(os.getcwd())/"models"
    
device = "cuda" if torch.cuda.is_available() else "cpu"

# Network Design


## ResNeXt Block


In [2]:
class ResNeXtBlock(nn.Module):
    """Residual bottleneck block built around a grouped 3x3 convolution.

    Main path: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU ->
    1x1 conv -> BN. The result is added to the shortcut (the input itself, or
    a 1x1-conv + BN projection of it) and passed through a final ReLU.

    Args:
        num_channels: Number of output channels of the block.
        groups: Divisor used to derive the group count of the 3x3 convolution,
            which is built with ``bot_channels // groups`` groups. Despite its
            name, this value is therefore not the group count itself; when it
            divides ``bot_channels`` it equals the channels per group.
        bot_mul: Multiplier for the bottleneck width,
            ``bot_channels = int(round(num_channels * bot_mul))``.
        use_conv1x1: If true, project the shortcut with a 1x1 convolution
            (using ``stride``) followed by batch norm; otherwise the input is
            added unchanged and must be addable to the main-path output.
        stride: Stride of the 3x3 convolution and of the shortcut projection.

    Raises:
        ValueError: If the derived group count is smaller than 1 or does not
            divide ``bot_channels``. Such a configuration cannot run a grouped
            convolution, so it is rejected at construction time rather than
            on the first forward pass.
    """

    def __init__(
        self, num_channels, groups, bot_mul, use_conv1x1=False, stride=1
    ) -> None:
        super().__init__()
        bot_channels = int(round(num_channels * bot_mul))
        conv_groups = bot_channels // groups
        if conv_groups < 1 or bot_channels % conv_groups != 0:
            raise ValueError(
                f"Invalid grouped convolution: bot_channels={bot_channels} with "
                f"groups={groups} gives {conv_groups} group(s), which must be "
                f">= 1 and divide bot_channels"
            )

        self.conv1 = nn.LazyConv2d(bot_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.LazyConv2d(
            bot_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=conv_groups,
        )
        self.conv3 = nn.LazyConv2d(num_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.LazyBatchNorm2d()
        self.bn2 = nn.LazyBatchNorm2d()
        self.bn3 = nn.LazyBatchNorm2d()
        if use_conv1x1:
            self.conv4 = nn.LazyConv2d(
                num_channels, kernel_size=1, stride=stride, padding=0
            )
            self.bn4 = nn.LazyBatchNorm2d()
        else:
            # Keep both shortcut attributes defined so they can be inspected
            # regardless of the configuration.
            self.conv4 = None
            self.bn4 = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply the bottleneck path and add the shortcut.

        Args:
            x: Input tensor accepted by ``nn.LazyConv2d``.

        Returns:
            ``relu(main_path(x) + shortcut(x))``.
        """
        # passing through bottleneck ie project to lower dimension
        y = F.relu(self.bn1(self.conv1(x)))
        # do the conv operation in low dimension and gather feature
        y = F.relu(self.bn2(self.conv2(y)))
        # projecting to original dimension (no ReLU until after the addition)
        y = self.bn3(self.conv3(y))
        # Explicit None check: do not rely on the truthiness of a module.
        if self.conv4 is not None:
            x = self.bn4(self.conv4(x))

        return F.relu(y + x)

# Any Network


In [None]:
class AnyNet(helper.Classifier):
    """Configurable CNN: a stem, a sequence of ResNeXt stages, and a linear head.

    The network is stored in ``self.net`` as an ``nn.Sequential`` whose
    children are the stem, ``stage1`` ... ``stageN`` (one per entry of
    ``arch``) and ``head``.

    Args:
        arch: One tuple per stage, unpacked positionally into :meth:`stage`
            as ``(num_channels, depth, bot_mul, groups)``.
        stem_channels: Number of output channels of the stem convolution.
        lr: Learning rate; only stored on the instance here.
            NOTE(review): presumably consumed by ``helper.Classifier`` — confirm.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(
        self,
        arch: Tuple[Tuple[int, int, float, int], ...],
        stem_channels: int,
        lr: float = 0.1,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.arch = arch
        self.stem_channels = stem_channels

        self.net = nn.Sequential(self.stem(self.stem_channels))

        # Stages are registered under 1-based names: stage1, stage2, ...
        for idx, stage_cfg in enumerate(arch, start=1):
            self.net.add_module(f'stage{idx}', self.stage(*stage_cfg))

        # Global average pooling to 1x1, flatten, then classify.
        self.net.add_module("head", nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.LazyLinear(self.num_classes),
        ))

    def stem(self, num_channels: int) -> nn.Module:
        """Build the stem: 7x7 stride-2 convolution, batch norm, ReLU.

        Args:
            num_channels: Number of output channels of the convolution.

        Returns:
            The stem as an ``nn.Sequential``.
        """
        return nn.Sequential(
            nn.LazyConv2d(num_channels, kernel_size=7, stride=2, padding=3),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),
        )

    def stage(self, num_channels: int, depth: int, bot_mul: float, groups: int) -> nn.Module:
        """Build one stage made of ``depth`` ResNeXt blocks.

        The first block uses stride 2 and a 1x1-conv shortcut; every
        following block uses the ``ResNeXtBlock`` defaults for both.

        Args:
            num_channels: Output channels of every block in the stage.
            depth: Number of blocks; ``depth <= 0`` yields an empty stage.
            bot_mul: Bottleneck multiplier forwarded to each block.
            groups: Forwarded unchanged to each block's ``groups`` argument.

        Returns:
            The blocks wrapped in an ``nn.Sequential``.
        """
        blocks = []
        for idx in range(depth):
            # Only the leading block downsamples and projects the shortcut.
            extra = {"use_conv1x1": True, "stride": 2} if idx == 0 else {}
            blocks.append(
                ResNeXtBlock(num_channels, groups=groups, bot_mul=bot_mul, **extra)
            )
        return nn.Sequential(*blocks)