In [1]:
# python libraries
import os
import sys
from pathlib import Path
from operator import methodcaller
from collections import OrderedDict
import dataclasses
from dataclasses import dataclass
from typing import (
    List,
    Tuple,
    Dict,
    Any,
    Mapping
)
from enum import Enum
# adding the path
if not str(Path(os.getcwd()).parent) in sys.path:
    sys.path.append(str(Path(os.getcwd()).parent))

# torch
import torch
import torchvision
from torch import nn
from torchvision import (
    transforms,
    datasets
    )
from torch.utils.data import DataLoader

try:
    from torchmetrics import Accuracy
except:
    print(f"[INFO] Installing the torchmetrics")
    %pip install torchmetrics
    from torchmetrics import Accuracy

# helper function
try:
    import my_helper as helper
except:
    print("[INFO] Downloading the helper function from github")
    import requests
    response = requests.get("https://raw.githubusercontent.com/Lashi0812/PyTorch2/master/my_helper.py")
    with open("my_helper.py" ,"wb") as f:
        f.write(response.content)
    import my_helper as helper

# LeNet

In [None]:
class LeNet(helper.Classifier):
    def __init__(self,lr,num_classes) -> None:
        super().__init__()
        self.lr = lr
        self.num_classes = num_classes
        self.net = nn.Sequential(OrderedDict([
            ("Conv2d_1",nn.LazyConv2d(out_channels=6,kernel_size=(5,5),padding=2)),
            ("Sigmoid_Conv2d_1",nn.Sigmoid()),
            ("AvgPool2d_1",nn.AvgPool2d(kernel_size=2,stride=2)),
            ("Conv2d_2",nn.LazyConv2d(out_channels=16,kernel_size=5)), #! There is no padding done 2nd conv layer
            ("Sigmoid_Conv2d_2",nn.Sigmoid()),
            ("AvgPool2d_1",nn.AvgPool2d(kernel_size=2,stride=2)),
            ("Flatten",nn.Flatten()),
            ("Linear_1",nn.LazyLinear(out_features=120)),
            ("Sigmoid_Linear_1",nn.Sigmoid()),
            ("Linear_2",nn.LazyLinear(out_features=84)),
            ("Sigmoid_Linear_2",nn.Sigmoid()),
            ("Linear_2",nn.LazyLinear(out_features=self.num_classes))
        ]))

In [None]:
lenet = LeNet(lr=0.01,num_classes=10)

In [None]:
lenet.layer_summary(input_shape=(1,1,28,28))

In [None]:
def init_cnn(module):
    if isinstance(module,(nn.Linear,nn.Conv2d)):
        nn.init.xavier_uniform_(module.weight)

# Data

In [None]:
class FashionMNIST(helper.DataModule):
    def __init__(self, batch_size: int = 64, resize=(28, 28)) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.resize = resize

        transform = transforms.Compose(
            [transforms.Resize(resize), transforms.ToTensor()]
        )

        self.train = datasets.FashionMNIST(
            root=self.root, train=True, transform=transform,download=True
        )
        self.val = datasets.FashionMNIST(
            root=self.root, train=False, transform=transform,download=True
        )
        self.classes = self.train.classes
        self.class_to_idx = self.train.class_to_idx

    def text_labels(self, indices: List):
        return [self.classes[a] for a in indices]

    def get_dataloader(self, train: bool):
        data = self.train if train else self.val
        return DataLoader(
            dataset=data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=train,
        )
    
    def visualize(self,batch:Tuple,num_rows=1,num_cols=8):
        X,y = batch
        labels = self.text_labels(y)
        helper.show_images(X.squeeze(1),num_rows=num_rows,num_cols=num_cols,titles=labels)

# Training the model

In [None]:
data = FashionMNIST(batch_size=128)
model = LeNet(lr=0.1,num_classes=len(data.classes))
model.apply_init([next(iter(data.get_dataloader(True)))[0]],init_cnn)


In [None]:
trainer = helper.Trainer(max_epochs=3)
trainer.fit(model,data)

# Experiments


## Exp 1
1. Replace the average pooling with max-pooling.
2. Replace the softmax layer with ReLU.

In [None]:
list(model.net.named_children())

In [None]:
def modify_model(model, changes: Dict):
    layers = nn.Sequential()
    for name, layer in model.net.named_children():
        layer_name = layer.__class__.__name__
        if layer_name in changes.keys():
            layers.add_module(
                name=name,
                module=methodcaller(changes[layer_name][0], **changes[layer_name][1])(nn),
            )
        else:
            layers.add_module(name=name, module=layer)
    return layers

In [None]:
change_dict = {"AvgPool2d":("MaxPool2d",{"kernel_size":2,"stride":2}),
               "Sigmoid":("ReLU",{})}

In [None]:
model.net = modify_model(model,change_dict)

In [None]:
model.net

In [None]:
trainer = helper.Trainer(max_epochs=3)
trainer.fit(model,data)

In [None]:
model.net

## Exp2 - Change the kernel size

In [158]:
@dataclass(slots=True)
class A:
    args:List
    kwargs:Any

    def __init__(self,*args,**kwargs) -> None:
        self.args = args
        self.kwargs = kwargs

In [159]:
a = A(Resize={"size":(28,28)},ToTensor={})

In [160]:
a.kwargs

{'Resize': {'size': (28, 28)}, 'ToTensor': {}}

In [5]:
b = A(**dict(Resize={"size":(28,28)},ToTensor={}))

In [6]:
b.kwargs

{'Resize': {'size': (28, 28)}, 'ToTensor': {}}

In [7]:
b.args

()

In [8]:
@dataclass(init=True,slots=True)
class DataSetting:
    batch_size:int
    data:helper.DataLoader
    kwargs:Mapping[str,Dict]
    def __init__(self,*,batch_size:int,data:helper.DataModule,**kwargs) -> None:
        self.batch_size = batch_size
        self.data = data
        self.kwargs = kwargs
    
    def replace(self,**kwargs):
        for k,v in kwargs:
            pass
            
        
        

In [9]:
setting = DataSetting(batch_size=32,data=helper.FashionMNIST,Resize={"size":(28,28)},ToTensor={})

In [10]:
setting.batch_size

32

In [11]:
setting.kwargs

{'Resize': {'size': (28, 28)}, 'ToTensor': {}}

In [12]:
[print(field.name) for field in setting.__dataclass_fields__.values()]

batch_size
data
kwargs


[None, None, None]

In [13]:
dataclasses.asdict(dataclasses.replace(dataclasses.replace(setting)))

{'batch_size': 32,
 'data': my_helper.FashionMNIST,
 'kwargs': {'kwargs': {'kwargs': {'Resize': {'size': (28, 28)},
    'ToTensor': {}}}}}

In [14]:
dataclasses.asdict(setting)

{'batch_size': 32,
 'data': my_helper.FashionMNIST,
 'kwargs': {'Resize': {'size': (28, 28)}, 'ToTensor': {}}}

In [15]:
@dataclass(init=False)
class ArgHolder:
    args: List[Any]
    kwargs: Mapping[Any, Any]

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
    
    def replace(self, /, **changes):
        return dataclasses.replace(self, **changes) 

a = ArgHolder(1, 2, three=3)

In [16]:
dataclasses.replace(dataclasses.replace(a),a=10)

ArgHolder(args=(), kwargs={'a': 10, 'args': (), 'kwargs': {'args': (1, 2), 'kwargs': {'three': 3}}})

In [17]:
a.replace(args=10)

ArgHolder(args=(), kwargs={'args': 10, 'kwargs': {'three': 3}})

In [18]:
import inspect

In [19]:
inspect.getfullargspec(transforms.Resize)

FullArgSpec(args=['self', 'size', 'interpolation', 'max_size', 'antialias'], varargs=None, varkw=None, defaults=(<InterpolationMode.BILINEAR: 'bilinear'>, None, None), kwonlyargs=[], kwonlydefaults=None, annotations={})

In [20]:
dir(inspect.signature(transforms.Resize))

['__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '_bind',
 '_bound_arguments_cls',
 '_hash_basis',
 '_parameter_cls',
 '_parameters',
 '_return_annotation',
 'bind',
 'bind_partial',
 'empty',
 'from_builtin',
 'from_callable',
 'from_function',
 'parameters',
 'replace',
 'return_annotation']

In [21]:
dict(inspect.signature(transforms.Resize).parameters).values()

dict_values([<Parameter "size">, <Parameter "interpolation=<InterpolationMode.BILINEAR: 'bilinear'>">, <Parameter "max_size=None">, <Parameter "antialias=None">])

In [22]:
inspect.signature(transforms.Resize)

<Signature (size, interpolation=<InterpolationMode.BILINEAR: 'bilinear'>, max_size=None, antialias=None)>

In [23]:
dir(inspect.signature(transforms.Resize))

['__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '_bind',
 '_bound_arguments_cls',
 '_hash_basis',
 '_parameter_cls',
 '_parameters',
 '_return_annotation',
 'bind',
 'bind_partial',
 'empty',
 'from_builtin',
 'from_callable',
 'from_function',
 'parameters',
 'replace',
 'return_annotation']

In [24]:
for params in inspect.signature(transforms.Resize).parameters.values():
    print(params.name,params.default,type(params.default),isinstance(params.default,Enum),params.annotation)
    if isinstance(params.default,Enum):
        print(params.default.value)


size <class 'inspect._empty'> <class 'type'> False <class 'inspect._empty'>
interpolation InterpolationMode.BILINEAR <enum 'InterpolationMode'> True <class 'inspect._empty'>
bilinear
max_size None <class 'NoneType'> False <class 'inspect._empty'>
antialias None <class 'NoneType'> False <class 'inspect._empty'>


In [25]:
inspect.signature(transforms.AutoAugment).parameters["interpolation"]

<Parameter "interpolation: torchvision.transforms.functional.InterpolationMode = <InterpolationMode.NEAREST: 'nearest'>">

In [26]:
dir(params)

['KEYWORD_ONLY',
 'POSITIONAL_ONLY',
 'POSITIONAL_OR_KEYWORD',
 'VAR_KEYWORD',
 'VAR_POSITIONAL',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '_annotation',
 '_default',
 '_kind',
 '_name',
 'annotation',
 'default',
 'empty',
 'kind',
 'name',
 'replace']

In [27]:
type(inspect.getfullargspec(transforms.Resize).defaults[0])

<enum 'InterpolationMode'>

In [28]:
@dataclass
class Test:
    name:str

In [29]:
t = Test("sdf")

In [30]:
dataclasses.asdict(t)

{'name': 'sdf'}

In [31]:
t.age = 10

In [32]:
dataclasses.asdict(t)

{'name': 'sdf'}

In [33]:
transforms.transforms.__all__

['Compose',
 'ToTensor',
 'PILToTensor',
 'ConvertImageDtype',
 'ToPILImage',
 'Normalize',
 'Resize',
 'CenterCrop',
 'Pad',
 'Lambda',
 'RandomApply',
 'RandomChoice',
 'RandomOrder',
 'RandomCrop',
 'RandomHorizontalFlip',
 'RandomVerticalFlip',
 'RandomResizedCrop',
 'FiveCrop',
 'TenCrop',
 'LinearTransformation',
 'ColorJitter',
 'RandomRotation',
 'RandomAffine',
 'Grayscale',
 'RandomGrayscale',
 'RandomPerspective',
 'RandomErasing',
 'GaussianBlur',
 'InterpolationMode',
 'RandomInvert',
 'RandomPosterize',
 'RandomSolarize',
 'RandomAdjustSharpness',
 'RandomAutocontrast',
 'RandomEqualize']

In [34]:
transforms.autoaugment.__all__

['AutoAugmentPolicy',
 'AutoAugment',
 'RandAugment',
 'TrivialAugmentWide',
 'AugMix']

In [35]:
for trans in transforms.transforms.__all__[:5]:
    if not issubclass(getattr(transforms,trans),Enum):
        sig = inspect.signature(getattr(transforms,trans))
        print(f"{trans} has total {len(sig.parameters.values())} paramaters")
        for params in sig.parameters.values():
            print(f"\tName  : {params.name}")
            print(f"\tdefault  : {params.default}")
    

Compose has total 1 paramaters
	Name  : transforms
	default  : <class 'inspect._empty'>
ToTensor has total 0 paramaters
PILToTensor has total 0 paramaters
ConvertImageDtype has total 1 paramaters
	Name  : dtype
	default  : <class 'inspect._empty'>
ToPILImage has total 1 paramaters
	Name  : mode
	default  : None


In [36]:
getattr(transforms,"AutoAugment")

torchvision.transforms.autoaugment.AutoAugment

In [37]:
sig = inspect.signature(getattr(transforms,"Resize"))
print(f"{trans} has total {len(sig.parameters.values())} parameters")
fields = []
for params in sig.parameters.values():
    if issubclass(type(params.default),Enum):
        default = params.default.value
    elif params.default is inspect._empty:
        default = dataclasses.MISSING
    else:
        default = params.default
        
    fields.append((params.name,Any,dataclasses.field(default=default)))

ToPILImage has total 4 parameters


In [38]:
fields

[('size',
  typing.Any,
  Field(name=None,type=None,default=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,_field_type=None)),
 ('interpolation',
  typing.Any,
  Field(name=None,type=None,default='bilinear',default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,_field_type=None)),
 ('max_size',
  typing.Any,
  Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,_field_type=None)),
 ('antialias',
  typing.Any,
  Field(name=None,type=N

In [39]:
issubclass(transforms.AutoAugmentPolicy,Enum)

True

In [40]:
fields

[('size',
  typing.Any,
  Field(name=None,type=None,default=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,_field_type=None)),
 ('interpolation',
  typing.Any,
  Field(name=None,type=None,default='bilinear',default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,_field_type=None)),
 ('max_size',
  typing.Any,
  Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,_field_type=None)),
 ('antialias',
  typing.Any,
  Field(name=None,type=N

In [41]:
C = dataclasses.make_dataclass("C",fields)

In [42]:
[t.name for t in dataclasses.fields(C) if t.default is dataclasses.MISSING]

['size']

In [43]:
c = C(10)

In [44]:
c.max_size

In [45]:
c.interpolation

'bilinear'

In [46]:
c.size

10

In [47]:
dataclasses.fields(c)

(Field(name='size',type=typing.Any,default=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='interpolation',type=typing.Any,default='bilinear',default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='max_size',type=typing.Any,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='antialias',type=typing.Any,default=None,default_factory=<dataclasses._MISSING_TYPE object at 0x000002826E0BD3C0>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD))

In [48]:
_original_create_fn = dataclasses._create_fn

def _new_create_fn(name, args, body, **kwargs):
    args_str = ', '.join(args)
    body_str = '\n'.join('  ' + l for l in body)
    print(f'def {name}({args_str}):\n{body_str}\n')
    return _original_create_fn(name, args, body, **kwargs)

# dataclasses._create_fn = _new_create_fn

In [49]:
C = dataclasses.make_dataclass("C",fields)

In [50]:
type(C)

type

In [51]:
c = C(10)

In [52]:
type(c)

types.C

In [53]:
type(c.size)

int

In [54]:
type(C)

type

In [55]:
type(C(1,1))

types.C

In [56]:
issubclass(type(c),C)

True

In [57]:
isinstance(c,C)

True

In [58]:
def create_fields(class_name):
    # getting the signature for the class
    sig = inspect.signature(getattr(transforms,class_name))
    fields = []
    for params in sig.parameters.values():
        # if issubclass(type(params.default),Enum):
        #     default = params.default.value
        if params.default is inspect._empty:
            default = dataclasses.MISSING
        else:
            default = params.default
        
        fields.append((params.name,Any,dataclasses.field(default=default)))
    return fields

In [59]:
auto = dataclasses.make_dataclass("Auto",create_fields("AutoAugment"))

In [60]:
transforms.AutoAugment(**dataclasses.asdict(auto()))

AutoAugment(policy=AutoAugmentPolicy.IMAGENET, fill=None)

In [61]:
trans_class = []
for trans_name in transforms.transforms.__all__ + transforms.autoaugment.__all__:
    trans_class.append((
        trans_name,
        trans_name,
        dataclasses.field(
            default=dataclasses.make_dataclass(
                trans_name, create_fields(trans_name)
            )
        ),
    ))


In [66]:
AllTransformer = dataclasses.make_dataclass("AllTransformer",trans_class,repr=False,eq=False,kw_only=True,slots=True)

In [67]:
[f.name for f in dataclasses.fields(AllTransformer().TrivialAugmentWide)]

['num_magnitude_bins', 'interpolation', 'fill']

In [68]:
dataclasses.replace(AllTransformer().Resize(20),size=10)

Resize(size=10, interpolation=<InterpolationMode.BILINEAR: 'bilinear'>, max_size=None, antialias=None)

In [None]:
dataclasses.asdict(AllTransformer().AutoAugment())

{'policy': 'imagenet', 'interpolation': 'nearest', 'fill': None}

In [133]:
class TransformSetting:
    @staticmethod
    def _get_all_transformer():
        trans_class = []
        for trans_name in (
            transforms.transforms.__all__ + transforms.autoaugment.__all__
        ):
            trans_class.append(
                (
                    trans_name,
                    trans_name,
                    dataclasses.field(
                        default=dataclasses.make_dataclass(
                            trans_name, create_fields(trans_name)
                        )
                    ),
                )
            )
        return dataclasses.make_dataclass(
            "AllTransformer",
            trans_class,
            repr=False,
            eq=False,
            kw_only=True,
            slots=True,
        )

    @staticmethod
    def _create_fields(class_name):
        # getting the signature for the class
        sig = inspect.signature(getattr(transforms, class_name))
        fields = []
        for params in sig.parameters.values():
            # if issubclass(type(params.default),Enum):
            #     default = params.default.value
            if params.default is inspect._empty:
                default = dataclasses.MISSING
            else:
                default = params.default

            fields.append((params.name, Any, dataclasses.field(default=default)))
        return fields

    all_transformer = _get_all_transformer.__func__()

    def __init__(self, **kwargs) -> None:
        self.transform_name = []
        self.transform_arg = []
        for k, v in kwargs.items():
            if k in dir(self.__class__.all_transformer):
                self.transform_name.append(k)
                self.transform_arg.append(
                    getattr(self.__class__.all_transformer(), k)(**v)
                )


In [138]:
base = TransformSetting(Resize={"size":10},ToTensor={})

In [136]:
dataclasses.replace(base.transform_arg[0],size=20)

Resize(size=20, interpolation=<InterpolationMode.BILINEAR: 'bilinear'>, max_size=None, antialias=None)

In [139]:
base.transform_name

['Resize', 'ToTensor']

In [140]:
base.transform_arg

[Resize(size=10, interpolation=<InterpolationMode.BILINEAR: 'bilinear'>, max_size=None, antialias=None),
 ToTensor()]

In [151]:
@dataclass
class BaseExperiment:
    transform:TransformSetting

In [152]:
base = BaseExperiment(TransformSetting(Resize={"size":(28,28)},ToTensor={}))

In [154]:
base.transform.transform_arg

[Resize(size=(28, 28), interpolation=<InterpolationMode.BILINEAR: 'bilinear'>, max_size=None, antialias=None),
 ToTensor()]

In [145]:
@dataclass
class Experiment:
    base_setting:BaseExperiment
    transform_changes:Dict[str,Dict]

    def modify_base_transform(self):
        modified_list = []
        for trans in self.base_setting.transform.transform_name:
            if self.transform_changes.get(trans) is None:
                modified_list.append(trans)
        

In [157]:
if dict(hi="sdf").get("hi") is None:
    print("sdf")