<a href="https://colab.research.google.com/github/TIMEdilation584/JP_Loksatta_moving_hearts/blob/master/ResNet_RS_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install dependencies & get ImageNette data
!pip install timm
!pip install wandb 
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz
!tar -xf /content/imagenette2-160.tgz
!pip install accelerate

In [2]:
# make required imports
import torch 
import timm 
import wandb 
import torchvision
import pandas as pd
import torch.nn as nn

from torchvision import transforms
from accelerate import Accelerator
from PIL import Image 
from matplotlib import pyplot as plt
from pathlib import Path 
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from dataclasses import dataclass
from timm.models.registry import register_model
from timm.models.helpers import build_model_with_cfg
from timm.models.resnet import Bottleneck, _create_resnet, default_cfgs, _cfg, make_blocks, create_classifier

# ResNet-RS Model Implementation

In [3]:
# add default configs 
default_cfgs['resnetrs50'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs101'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs152'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs200'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs270'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs350'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs420'] = _cfg(interpolation='bicubic', first_conv='conv1.0')
default_cfgs['resnetrs50']

{'classifier': 'fc',
 'crop_pct': 0.875,
 'first_conv': 'conv1.0',
 'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'num_classes': 1000,
 'pool_size': (7, 7),
 'std': (0.229, 0.224, 0.225),
 'url': ''}

In [4]:
# update the ResNet class implementation in TIMM to include changes for ResNet-RS models
# refer to blog for more details https://wandb.ai/wandb_fc/pytorch-image-models/reports/Revisiting-ResNets-Improved-Training-and-Scaling-Strategies--Vmlldzo2NDE3NTM
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, in_chans=3,
                 cardinality=1, base_width=64, stem_width=64, stem_type='',
                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                 drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, replace_stem_max_pool=False):
        block_args = block_args or dict()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.replace_stem_max_pool = replace_stem_max_pool
        super(ResNet, self).__init__()

        # Stem
        deep_stem = 'deep' in stem_type
        inplanes = stem_width * 2 if deep_stem else 64
        if deep_stem:
            stem_chs = (stem_width, stem_width)
            if 'tiered' in stem_type:
                stem_chs = (3 * (stem_width // 4), stem_width)
            self.conv1 = nn.Sequential(*[
                nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
                norm_layer(stem_chs[0]),
                act_layer(inplace=True),
                nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
                norm_layer(stem_chs[1]),
                act_layer(inplace=True),
                nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
        else:
            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(inplanes)
        self.act1 = act_layer(inplace=True)
        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]

        # Stem Pooling
        if not self.replace_stem_max_pool:
            if aa_layer is not None:
                self.maxpool = nn.Sequential(*[
                    nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                    aa_layer(channels=inplanes, stride=2)])
            else:
                self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.maxpool = nn.Sequential(*[
                nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1), 
                nn.BatchNorm2d(inplanes), 
                nn.ReLU()
            ])

        # Feature Blocks
        channels = [64, 128, 256, 512]
        stage_modules, stage_feature_info = make_blocks(
            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
        for stage in stage_modules:
            self.add_module(*stage)  # layer1, layer2, etc
        self.feature_info.extend(stage_feature_info)

        # Head (Pooling and Classifier)
        self.num_features = 512 * block.expansion
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

    def get_classifier(self):
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        return x

In [5]:
def _create_resnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)

In [6]:
@register_model
def resnetrs50(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs50', pretrained, **model_args)


@register_model
def resnetrs101(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs101', pretrained, **model_args)


@register_model
def resnetrs152(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs152', pretrained, **model_args)


@register_model
def resnetrs200(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs200', pretrained, **model_args)


@register_model
def resnetrs270(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs270', pretrained, **model_args)



@register_model
def resnetrs350(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs350', pretrained, **model_args)


@register_model
def resnetrs420(pretrained=False, **kwargs):
    model_args = dict(
        block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
        avg_down=True,  block_args=dict(attn_layer='se'), **kwargs)
    return _create_resnet('resnetrs420', pretrained, **model_args)

In [7]:
# create ResNet-RS 50 model 
model_resnetrs50 = timm.create_model('resnetrs50')
# create Dummy inputs
x = torch.randn(1, 3, 224, 224)
# forward pass
model_resnetrs50(x).shape

torch.Size([1, 1000])

# Train on ImageNette

## Training Config

In [8]:
# Config for training 
IMG_SIZE = 160

Config = dict(
    DATA_DIR="/content/imagenette2-160",
    TRAIN_DATA_DIR="/content/imagenette2-160/train",
    TEST_DATA_DIR="/content/imagenette2-160/val",
    DEVICE="cuda",
    PRETRAINED=False,
    LR=1e-5,
    EPOCHS=5,
    IMG_SIZE=IMG_SIZE,
    BS=64,
    TRAIN_AUG=transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.RandomErasing(0.2),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    ),
    TEST_AUG=transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    ),
)

## Train and Eval loop

In [9]:
def train_fn(model, train_data_loader, optimizer, epoch, accelerator):
    model.train()
    fin_loss = 0.0
    tk = tqdm(train_data_loader, desc="Epoch" + " [TRAIN] " + str(epoch + 1))

    for t, data in enumerate(tk):
        optimizer.zero_grad()
        out = model(data[0])
        loss = nn.CrossEntropyLoss()(
            out, data[1]
            )
        accelerator.backward(loss)
        optimizer.step()

        fin_loss += loss.item()
        tk.set_postfix(
            {
                "loss": "%.6f" % float(fin_loss / (t + 1)),
                "LR": optimizer.param_groups[0]["lr"],
            }
        )
    return fin_loss/len(train_data_loader), optimizer.param_groups[0]["lr"]

In [10]:
def eval_fn(model, eval_data_loader, epoch):
    model.eval()
    fin_loss = 0.0
    tk = tqdm(eval_data_loader, desc="Epoch" + " [VALID] " + str(epoch + 1))

    with torch.no_grad():
        for t, data in enumerate(tk):
            out = model(data[0])
            loss = nn.CrossEntropyLoss()(
                out, data[1]
                )
            fin_loss += loss.item()
            tk.set_postfix({"loss": "%.6f" % float(fin_loss / (t + 1))})
        return fin_loss/len(eval_data_loader)

In [11]:
def train():
    accelerator = Accelerator()

    # wandb init
    wandb.init(config=Config, project='ImageNette', save_code=True, 
           job_type='train', tags=['resnetrs', 'imagenette'], 
           name=Config['MODEL'])    
    
    # train and eval datasets 
    train_dataset = torchvision.datasets.ImageFolder(
        Config['TRAIN_DATA_DIR'], 
        transform=Config['TRAIN_AUG']
        )
    eval_dataset = torchvision.datasets.ImageFolder(
        Config['TEST_DATA_DIR'], 
        transform=Config['TEST_AUG']
        )

    # train and eval dataloaders
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=Config["BS"]
    )
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=Config["BS"]
    )

    # model
    model = timm.create_model(
        Config['MODEL'], 
        pretrained=Config['PRETRAINED']
        )

    # optimizer    
    optimizer = torch.optim.Adam(
        model.parameters(), lr=Config["LR"]
    )

    # prepare for DDP
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader)    

    for epoch in range(Config["EPOCHS"]):
        avg_loss_train, lr = train_fn(
            model, train_dataloader, optimizer, epoch, accelerator)
        avg_loss_eval = eval_fn(
            model, eval_dataloader, epoch)
        wandb.log({'train_loss': avg_loss_train, 'eval_loss': avg_loss_eval, 'lr': lr})

## Train models

In [12]:
for model in ['resnetrs50', 'resnetrs101', 'resnet50']:
  Config['MODEL'] = model
  train()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch [TRAIN] 1: 100%|██████████| 148/148 [00:39<00:00,  3.72it/s, loss=6.758239, LR=1e-5]
Epoch [VALID] 1: 100%|██████████| 62/62 [00:08<00:00,  6.96it/s, loss=6.014195]
Epoch [TRAIN] 2: 100%|██████████| 148/148 [00:39<00:00,  3.76it/s, loss=5.299781, LR=1e-5]
Epoch [VALID] 2: 100%|██████████| 62/62 [00:08<00:00,  6.99it/s, loss=4.613536]
Epoch [TRAIN] 3: 100%|██████████| 148/148 [00:39<00:00,  3.75it/s, loss=4.017516, LR=1e-5]
Epoch [VALID] 3: 100%|██████████| 62/62 [00:09<00:00,  6.88it/s, loss=3.463946]
Epoch [TRAIN] 4: 100%|██████████| 148/148 [00:39<00:00,  3.74it/s, loss=3.183121, LR=1e-5]
Epoch [VALID] 4: 100%|██████████| 62/62 [00:09<00:00,  6.88it/s, loss=2.835897]
Epoch [TRAIN] 5: 100%|██████████| 148/148 [00:39<00:00,  3.74it/s, loss=2.781310, LR=1e-5]
Epoch [VALID] 5: 100%|██████████| 62/62 [00:08<00:00,  6.89it/s, loss=2.571084]


VBox(children=(Label(value=' 0.04MB of 0.04MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,2.78131
eval_loss,2.57108
lr,1e-05
_runtime,252.0
_timestamp,1622801295.0
_step,4.0


0,1
train_loss,█▅▃▂▁
eval_loss,█▅▃▂▁
lr,▁▁▁▁▁
_runtime,▁▃▄▆█
_timestamp,▁▃▄▆█
_step,▁▃▅▆█


Epoch [TRAIN] 1: 100%|██████████| 148/148 [00:53<00:00,  2.76it/s, loss=6.766204, LR=1e-5]
Epoch [VALID] 1: 100%|██████████| 62/62 [00:10<00:00,  5.94it/s, loss=5.989587]
Epoch [TRAIN] 2: 100%|██████████| 148/148 [00:53<00:00,  2.78it/s, loss=5.316126, LR=1e-5]
Epoch [VALID] 2: 100%|██████████| 62/62 [00:10<00:00,  5.97it/s, loss=4.587626]
Epoch [TRAIN] 3: 100%|██████████| 148/148 [00:53<00:00,  2.77it/s, loss=4.028769, LR=1e-5]
Epoch [VALID] 3: 100%|██████████| 62/62 [00:10<00:00,  5.95it/s, loss=3.459581]
Epoch [TRAIN] 4: 100%|██████████| 148/148 [00:53<00:00,  2.78it/s, loss=3.186336, LR=1e-5]
Epoch [VALID] 4: 100%|██████████| 62/62 [00:10<00:00,  6.00it/s, loss=2.842945]
Epoch [TRAIN] 5: 100%|██████████| 148/148 [00:53<00:00,  2.78it/s, loss=2.780900, LR=1e-5]
Epoch [VALID] 5: 100%|██████████| 62/62 [00:10<00:00,  5.95it/s, loss=2.572785]


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,2.7809
eval_loss,2.57279
lr,1e-05
_runtime,322.0
_timestamp,1622801623.0
_step,4.0


0,1
train_loss,█▅▃▂▁
eval_loss,█▅▃▂▁
lr,▁▁▁▁▁
_runtime,▁▃▅▆█
_timestamp,▁▃▅▆█
_step,▁▃▅▆█


Epoch [TRAIN] 1: 100%|██████████| 148/148 [00:34<00:00,  4.25it/s, loss=6.829496, LR=1e-5]
Epoch [VALID] 1: 100%|██████████| 62/62 [00:08<00:00,  7.33it/s, loss=6.042809]
Epoch [TRAIN] 2: 100%|██████████| 148/148 [00:34<00:00,  4.28it/s, loss=5.342644, LR=1e-5]
Epoch [VALID] 2: 100%|██████████| 62/62 [00:08<00:00,  7.32it/s, loss=4.629596]
Epoch [TRAIN] 3: 100%|██████████| 148/148 [00:34<00:00,  4.27it/s, loss=4.015556, LR=1e-5]
Epoch [VALID] 3: 100%|██████████| 62/62 [00:08<00:00,  7.29it/s, loss=3.520964]
Epoch [TRAIN] 4: 100%|██████████| 148/148 [00:34<00:00,  4.27it/s, loss=3.158181, LR=1e-5]
Epoch [VALID] 4: 100%|██████████| 62/62 [00:08<00:00,  7.31it/s, loss=2.907657]
Epoch [TRAIN] 5: 100%|██████████| 148/148 [00:34<00:00,  4.27it/s, loss=2.755575, LR=1e-5]
Epoch [VALID] 5: 100%|██████████| 62/62 [00:08<00:00,  7.22it/s, loss=2.623513]
