In [None]:
# Record the environment for reproducibility: load the `watermark` IPython
# extension, then print the author tag together with interpreter and torch
# version information.
%load_ext watermark
%watermark -a 'NavinKumarMNK' -v -p torch

In [3]:
# Import the required modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision import models
import wandb
import torchmetrics


In [4]:
#model = models.efficientnet_v2_m(include_top=False, weights='EfficientNet_V2_M_Weights.DEFAULT')

#model = models.efficientnet_b3(include_top=False, weights='EfficientNet_B3_Weights.DEFAULT')
# remove last layer
#model.classifier = nn.Identity()

# Add the parent directory to the path. The membership check keeps this cell
# idempotent: re-running it no longer appends a duplicate sys.path entry.
import sys
import os
if os.path.abspath('../') not in sys.path:
    sys.path.append(os.path.abspath('../'))

# Build the encoder from the project package and print its layer/parameter
# summary (the summary reports a CoAtNet backbone producing a [1, 1024] feature).
from models.EfficientNetv2.EncoderCoAtNet import EncoderCoAtNet
model = EncoderCoAtNet(num_blocks=[2, 2, 6, 14, 2])
model.summarize()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  exec(code_obj, self.user_global_ns, self.user_ns)


  | Name  | Type    | Params | In sizes         | Out sizes
-----------------------------------------------------------------
0 | model | CoAtNet | 55.1 M | [1, 3, 256, 256] | [1, 1024]
-----------------------------------------------------------------
55.1 M    Trainable params
0         Non-trainable params
55.1 M    Total params
220.272   Total estimated model params size (MB)

In [None]:
# Serialize the entire encoder module object (not just its state_dict) to disk.
# NOTE(review): a whole-module pickle is tied to the import path of the class
# that defines it; confirm this is preferred over saving model.state_dict().
torch.save(model, '../weights/EfficientNetv2Encoder.pt')

In [None]:
# Reload the pickled encoder module from disk into a separate name.
# NOTE(review): torch.load unpickles arbitrary objects, so only point it at
# trusted files; newer PyTorch releases may need weights_only=False to load a
# full-module checkpoint like this one — confirm against the installed version.
model_1 = torch.load('../weights/EfficientNetv2Encoder.pt')

In [9]:
import torch.nn as nn

class SEAttention(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Each channel of the input is averaged over its spatial extent, the
    resulting per-channel vector is passed through a two-layer bottleneck
    ending in a sigmoid, and the input is rescaled channel-wise by that gate.

    Args:
        in_channels: number of channels ``C`` of the ``(B, C, H, W)`` input.
        reduction: bottleneck ratio; the hidden width is
            ``in_channels // reduction``, clamped to at least 1.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Without the clamp, in_channels < reduction gives a zero-width
        # bottleneck: the second bias-free Linear then always outputs zeros and
        # the gate degenerates to a constant sigmoid(0) = 0.5.
        hidden_channels = max(1, in_channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by a gate in (0, 1); shape is unchanged."""
        b, c, _, _ = x.size()
        # Squeeze: (B, C, H, W) -> (B, C)
        y = self.avg_pool(x).view(b, c)
        # Excite: (B, C) -> (B, C, 1, 1) so it broadcasts over H and W
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
    
class Decoder(nn.Module):
    """Decode a 1024-d feature vector into a 3-channel 256x256 image.

    Pipeline (as built below):
      1. A linear layer expands the feature to 4*4*1024 values, which are
         batch-normalised, passed through dropout and GELU, and reshaped to
         ``(B, 1024, 4, 4)``.
      2. Six transposed convolutions (kernel 4, stride 2, padding 1) each
         double the spatial size: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256,
         while the channel count shrinks 1024 -> 512 -> ... -> 16.
      3. A residual branch with SE attention is added after stage 2
         (256 channels) and after stage 4 (64 channels).
      4. A 3x3 convolution maps 16 channels to 3 and a sigmoid squashes the
         result into [0, 1].

    NOTE(review): the original ``forward`` called ``self.tanh``, which was never
    created, so every forward pass raised ``AttributeError``. ``__init__`` only
    defines ``self.sigmoid``, so that is what is applied now. If the training
    targets are normalised to [-1, 1] a tanh head would be the right choice
    instead — confirm against the training pipeline.
    """

    def __init__(self):
        super().__init__()

        # Initial representation
        self.fc = nn.Linear(1024, 4*4*1024)
        self.bn1d = nn.BatchNorm1d(4*4*1024)
        self.gelu = nn.GELU()

        # Decoder layers (each ConvTranspose2d doubles height and width)
        self.conv1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn1 = nn.BatchNorm2d(512)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()

        self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.ReLU()

        self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.ReLU()

        self.conv6 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, output_padding=0)
        self.bn6 = nn.BatchNorm2d(16)
        self.relu6 = nn.ReLU()

        # Residual blocks with SE attention.
        # res2 is registered before res1 to keep the original module order.
        self.res2 = self._se_residual_branch(64)
        self.res1 = self._se_residual_branch(256)

        # One dropout module shared by every place forward() applies dropout.
        self.dropout = nn.Dropout(0.25)

        self.conv7 = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _se_residual_branch(channels):
        """Build the conv-BN-ReLU-conv-BN-sigmoid-SE-ReLU branch used by res1/res2."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.Sigmoid(),
            SEAttention(channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Map features of shape ``(B, 1024)`` to images of shape ``(B, 3, 256, 256)``."""
        # Expand the feature vector and reshape it into a 4x4 map.
        x = self.fc(x)
        x = self.bn1d(x)
        x = self.dropout(x)
        x = self.gelu(x)
        x = x.view(-1, 1024, 4, 4)

        # Stage 1: 4x4 -> 8x8
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        # Stage 2: 8x8 -> 16x16 (dropout sits between BN and the activation)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dropout(x)
        x = self.relu2(x)

        # Residual refinement at 256 channels.
        x = self.res1(x) + x

        # Stage 3: 16x16 -> 32x32
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        # Stage 4: 32x32 -> 64x64
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = self.relu4(x)

        # Residual refinement at 64 channels.
        x = self.res2(x) + x

        # Stage 5: 64x64 -> 128x128
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)

        # Stage 6: 128x128 -> 256x256
        x = self.conv6(x)
        x = self.bn6(x)
        x = self.relu6(x)

        # Output head: use the activation that __init__ actually defines.
        x = self.conv7(x)
        x = self.sigmoid(x)

        return x


In [10]:
# Smoke test: push one random batch through encoder -> decoder and report the
# tensor shapes plus the decoder's parameter count.
# The selected device is now bound and used; previously the torch.device(...)
# expression was evaluated and discarded, so everything silently ran on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = model.to(device)
decoder = Decoder().to(device)

# Named `sample_batch` rather than `input` so the builtin input() is not shadowed.
sample_batch = torch.randn(16, 3, 256, 256, device=device)

# Only shapes are inspected here, so skip building autograd graphs.
with torch.no_grad():
    feature = encoder(sample_batch)
    print(feature.shape)
    output = decoder(feature)
    print(output.shape)
# difference between input and output

model2 = decoder
total_params = sum(p.numel() for p in model2.parameters())
print(total_params)


torch.Size([16, 1024])


In [None]:
# Persist the decoder's weights. `model` is still bound to the encoder at this
# point, so saving model.state_dict() here would write encoder weights into the
# decoder checkpoint file.
torch.save(decoder.state_dict(), '../weights/EfficientNetv2DecoderLarge.pt')

In [None]:
# classifier
class EfficientNetv2Classifier(nn.Module):
    """MLP head turning a 1024-d feature vector into class probabilities.

    Layout: 1024 -> 512 -> 128 -> ``no_of_classes`` with ReLU between the
    linear layers and a softmax over dim 1 at the end.

    NOTE(review): because softmax is part of the module, the output is already
    a probability distribution; a loss that expects raw logits (for example
    nn.CrossEntropyLoss) would apply softmax a second time — confirm which loss
    is used with this head.

    Args:
        no_of_classes: width of the final linear layer.
    """

    def __init__(self, no_of_classes):
        super().__init__()
        hidden_widths = (1024, 512, 128)

        # Hidden stack: Linear + ReLU for each consecutive pair of widths.
        stack = []
        for fan_in, fan_out in zip(hidden_widths, hidden_widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())

        # Output layer followed by a row-wise softmax.
        stack.append(nn.Linear(hidden_widths[-1], no_of_classes))
        stack.append(nn.Softmax(dim=1))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return per-class probabilities for a ``(B, 1024)`` input."""
        probabilities = self.model(x)
        return probabilities


In [None]:
import sys
import os
# Put the parent directory on sys.path only if it is not already there, so
# re-running this cell does not add duplicate entries.
if os.path.abspath('../') not in sys.path:
    sys.path.append(os.path.abspath('../'))
# Project-local helper module, resolved through the path entry added above.
from utils import utils

In [None]:
# Look up 'no_of_classes' in the 'GENERAL' section returned by
# utils.config_parse and coerce it to int.
# (presumably the config values come back as strings — verify in utils)
classes = int(utils.config_parse('GENERAL')['no_of_classes'])
# Bare expression: shows the parsed class count as the cell output.
classes

In [None]:

# Instantiate the classification head for the configured class count, show its
# layer layout, then report how many parameters it holds.
# Note that this rebinds `model`, which previously referred to the encoder.
model = EfficientNetv2Classifier(classes)
print(model)
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(total_params)

In [None]:
# Save only the classifier head's weights (state_dict); at this point `model`
# refers to the EfficientNetv2Classifier instance, not the encoder.
torch.save(model.state_dict(), '../weights/EfficientNetv2Classifier.pt')