In [1]:
import torch
import random
import numpy as np

# Single seed shared by every random number generator seeded below.
seed_value = 42

# Seed PyTorch, Python's `random` module and NumPy with the same value.
torch.manual_seed(seed_value)
random.seed(seed_value)
np.random.seed(seed_value)

# Ensure deterministic behavior in convolutional layers: request
# deterministic cuDNN kernels and disable cuDNN's benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
import torch
import torch.nn as nn
from einops import rearrange

def conv_1x1_bn(inp, oup):
    """Build a pointwise block: 1x1 Conv2d -> BatchNorm2d -> Mish.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order. The
        convolution has no bias term.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    activation = nn.Mish()
    return nn.Sequential(pointwise, norm, activation)


def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """Build an NxN convolution block: Conv2d -> BatchNorm2d -> Mish.

    The padding is derived from the kernel size (``k // 2``) so that, for odd
    kernels at stride 1, the spatial size of the input is preserved. With the
    default 3x3 kernel this is a padding of 1.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        kernal_size: Kernel size, either an int or a per-dimension tuple.
            (The spelling is kept for compatibility with existing callers.)
        stride: Convolution stride.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order. The
        convolution has no bias term.
    """
    if isinstance(kernal_size, int):
        padding = kernal_size // 2
    else:
        padding = tuple(k // 2 for k in kernal_size)

    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, padding, bias=False),
        nn.BatchNorm2d(oup),
        nn.Mish()
    )

def depthwise_conv(inp, oup, kernel_size=3, stride=1, padding=1):
    """Build a depthwise-separable convolution block.

    The block is a depthwise convolution (one group per input channel)
    followed by a 1x1 pointwise convolution. Each convolution is bias-free and
    is followed by BatchNorm2d and SiLU.

    Args:
        inp: Number of input channels (also the depthwise output width).
        oup: Number of output channels of the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.

    Returns:
        An ``nn.Sequential`` with six layers.
    """
    layers = [
        # Depthwise stage: groups == channels, so each channel is filtered alone.
        nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.SiLU(),
        # Pointwise stage: 1x1 convolution that mixes channels.
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    ]
    return nn.Sequential(*layers)

class CoordAtt(nn.Module):
    """Coordinate attention gate.

    The input is average-pooled along the width and along the height, the two
    pooled maps are concatenated and passed through a shared 1x1 bottleneck,
    then split again to produce one sigmoid gate per row and one per column.
    The input is multiplied by both gates.

    Args:
        inp: Number of input channels.
        oup: Number of channels of the two gates. The gates are multiplied
            with the input, so this has to be broadcast-compatible with
            ``inp``.
        reduction: Channel reduction factor of the bottleneck; the bottleneck
            is never narrower than 8 channels.
    """

    def __init__(self, inp, oup, reduction=16):
        super().__init__()
        # (N, C, H, W) -> (N, C, H, 1) and (N, C, 1, W) respectively.
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        # Shared bottleneck applied to the concatenated pooled maps.
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.Hardswish()

        # Separate projections producing the per-row and per-column gates.
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``x`` scaled by its row gate and column gate."""
        _, _, height, width = x.size()

        pooled_rows = self.pool_h(x)
        # Move the width axis into the height slot so both maps can be stacked.
        pooled_cols = self.pool_w(x).permute(0, 1, 3, 2)

        stacked = torch.cat([pooled_rows, pooled_cols], dim=2)
        stacked = self.act(self.bn1(self.conv1(stacked)))

        row_feat, col_feat = torch.split(stacked, [height, width], dim=2)
        col_feat = col_feat.permute(0, 1, 3, 2)

        row_gate = self.conv_h(row_feat).sigmoid()
        col_gate = self.conv_w(col_feat).sigmoid()

        return x * col_gate * row_gate

class SE(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel is reduced to a scalar by global average pooling, the
    resulting vector is passed through a two-layer bottleneck MLP ending in a
    sigmoid, and the input is rescaled channel-wise by the result.

    Args:
        channel: Number of input (and output) channels.
        reduction_ratio: Factor by which the bottleneck is narrower than
            ``channel``. The bottleneck is clamped to at least one unit, so
            ``channel < reduction_ratio`` no longer yields a zero-width layer
            (which would make the gate a constant 0.5).
    """

    def __init__(self, channel, reduction_ratio =16):
        super().__init__()
        # Global average pooling: (N, C, H, W) -> (N, C, 1, 1).
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Integer division can reach 0 for narrow inputs; keep at least 1 unit.
        hidden = max(1, channel // reduction_ratio)

        # Bottleneck MLP producing one gate value in (0, 1) per channel.
        self.mlp = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` with every channel scaled by its learned gate."""
        b, c, _, _ = x.size()
        y = self.gap(x).view(b, c)
        y = self.mlp(y).view(b, c, 1, 1)
        # (N, C, 1, 1) broadcasts over the spatial dimensions of x.
        return x * y

class PreNorm(nn.Module):
    """Apply LayerNorm to the input, then hand the result to a wrapped callable.

    Args:
        dim: Size of the last dimension, which LayerNorm normalizes over.
        fn: Callable (typically a module) invoked on the normalized input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize ``x`` and forward it, with any extra kwargs, to ``fn``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand,
            nn.SiLU(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over the second-to-last axis.

    ``forward`` expects a 4-D input laid out as ``b p n (h d)`` after the QKV
    projection, i.e. attention runs across ``n`` independently for every
    ``(b, p)`` pair.

    Args:
        dim: Feature size of the input (last dimension).
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is only skipped when a single head already has
        # the input width.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        # One bias-free projection yields queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Return the attended features, same shape as ``x``."""
        # Split the fused projection into q/k/v and expose the head axis.
        q, k, v = (
            rearrange(part, 'b p n (h d) -> b p h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        mixed = torch.matmul(weights, v)
        # Fold the heads back into the feature axis.
        merged = rearrange(mixed, 'b p h n d -> b p n (h d)')
        return self.to_out(merged)


class Transformer(nn.Module):
    """Stack of pre-norm transformer layers with residual connections.

    Each layer is a pre-normalized attention sub-layer followed by a
    pre-normalized feed-forward sub-layer; both are added back to their input.

    Args:
        dim: Feature size of the tokens.
        depth: Number of layers.
        heads: Number of attention heads per layer.
        dim_head: Feature size of each attention head.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Pass ``x`` through every layer in order and return the result."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class MV2Block(nn.Module):
    """Inverted-residual block with a 5x5 depthwise conv and coordinate attention.

    Layout: 1x1 expansion -> 5x5 depthwise (carries the stride) -> CoordAtt ->
    1x1 linear projection. A skip connection is added only when the block
    keeps both the resolution (stride 1) and the channel count.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        expansion: Multiplier giving the hidden width, ``int(inp * expansion)``.
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = [
            # Pointwise expansion.
            nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.Hardswish(),
            # Depthwise 5x5; padding 2 keeps the size at stride 1.
            nn.Conv2d(hidden_dim, hidden_dim, 5, stride, 2, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.Hardswish(),
            # Attention on the expanded features (SE(hidden_dim) is the
            # alternative that was left commented out here).
            CoordAtt(hidden_dim, hidden_dim),
            # Pointwise linear projection: no activation afterwards.
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the skip connection when it is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out
    

class MobileViTBlock(nn.Module):
    """MobileViT block: local convolutions, patch-wise transformer, fusion.

    The input is encoded by a depthwise-separable conv and a 1x1 conv into
    ``dim`` channels, unfolded into ``ph x pw`` patches, processed by a
    transformer, folded back, projected to ``channel`` channels, concatenated
    with the local features and fused by a final 1x1 conv. The result is added
    to the block input.

    Args:
        dim: Token feature size used inside the transformer.
        depth: Number of transformer layers.
        channel: Number of input and output channels of the block.
        kernel_size: Kernel size of the local depthwise convolution. Padding
            is derived from it (``k // 2``) so odd kernels preserve the spatial
            size, which the final residual addition requires.
        patch_size: ``(ph, pw)`` patch height and width. The feature map's
            height and width must be divisible by them for the unfold to work.
        mlp_dim: Hidden width of the transformer feed-forward layers.
        dropout: Dropout probability used inside the transformer.
        heads: Number of attention heads (default 4, the value previously
            hard-coded).
        dim_head: Feature size of each head (default 8, the value previously
            hard-coded).
    """

    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0., heads=4, dim_head=8):
        super().__init__()
        self.ph, self.pw = patch_size

        # Keep the local conv size-preserving for any odd kernel, not just 3.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

        self.conv1 = depthwise_conv(channel, channel, kernel_size, padding=padding)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        # Fusion sees the local features (dim) concatenated with the folded
        # transformer output (channel).
        self.conv4 = conv_1x1_bn(channel + dim, channel)

    def forward(self, x):
        """Return ``x`` plus the fused local/global representation."""
        # Local representations
        fm_conv = self.conv1(x)
        fm_conv = self.conv2(fm_conv)

        # Global representations: unfold into patches, attend, fold back.
        _, _, h, w = fm_conv.shape
        patches = rearrange(fm_conv, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        patches = self.transformer(patches)
        fm = rearrange(patches, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        fm = self.conv3(fm)

        # Fusion
        concat = torch.cat((fm_conv, fm), 1)
        res = self.conv4(concat)

        # Nothing above modifies x in place, so the residual can use x
        # directly instead of a cloned copy.
        return x + res

class MobileViT(nn.Module):
    """MobileViT-style image classifier built from MV2 and MobileViT blocks.

    Pipeline: strided 3x3 stem -> five MV2 blocks -> three (MobileViTBlock,
    preceded by a strided MV2 block) stages -> 1x1 conv -> global average
    pooling -> linear classifier.

    Args:
        image_size: ``(height, width)`` of the expected input. Only used to
            assert divisibility by ``patch_size``.
        dims: Three transformer token sizes, one per MobileViT stage.
        channels: Eleven channel widths indexed 0..10. Consecutive stages are
            chained directly, so the configuration must satisfy
            ``channels[4] == channels[5]``, ``channels[6] == channels[7]`` and
            ``channels[8] == channels[9]``.
        num_classes: Number of output logits.
        expansion: Expansion factor of every MV2 block.
        kernel_size: Kernel size of the local conv in every MobileViT block.
        patch_size: ``(ph, pw)`` patch size of every MobileViT block.
        constant_factor: Width multiplier applied to every entry of
            ``channels`` and ``dims`` (results are truncated with ``int``).
    """

    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2), constant_factor=1.22):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        # Transformer depth of the three MobileViT stages.
        L = [2, 4, 3]

        # Apply the width multiplier; int() truncates toward zero.
        channels = [int(c * constant_factor) for c in channels]
        dims = [int(d * constant_factor) for d in dims]

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0]*2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1]*4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2]*4)))

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        # Repeat: this block consumes the output of the previous one, which
        # has channels[3] channels, so its input width must be channels[3]
        # (not channels[2]) for configs where the two differ.
        self.mv2.append(MV2Block(channels[3], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.fc = nn.Linear(channels[-1], num_classes, bias=True)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)      # Repeat

        x = self.mv2[4](x)
        x = self.mvit[0](x)

        x = self.mv2[5](x)
        x = self.mvit[1](x)

        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)

        # Global average pooling over the two spatial dimensions.
        x = torch.mean(x, dim=[2, 3])
        x = self.fc(x)
        return x


def mobilevit_xxs(num_classes=37):
    """Build the XXS variant: expansion 2, widths scaled by 1.22, 256x256 input.

    Args:
        num_classes: Number of output classes.

    Returns:
        A freshly constructed ``MobileViT`` model.
    """
    return MobileViT(
        image_size=(256, 256),
        dims=[64, 80, 96],
        channels=[16, 16, 24, 24, 64, 64, 80, 80, 128, 128, 512],
        num_classes=num_classes,
        expansion=2,
        constant_factor=1.22,
    )

def mobilevit_xs(num_classes=37):
    """Build the XS variant: default expansion, unscaled widths, 256x256 input.

    Args:
        num_classes: Number of output classes.

    Returns:
        A freshly constructed ``MobileViT`` model.
    """
    return MobileViT(
        image_size=(256, 256),
        dims=[96, 120, 144],
        channels=[16, 32, 48, 48, 96, 96, 160, 160, 160, 160, 640],
        num_classes=num_classes,
        constant_factor=1,
    )


def mobilevit_s(num_classes=37):
    """Build the S variant: default expansion, unscaled widths, 256x256 input.

    Args:
        num_classes: Number of output classes.

    Returns:
        A freshly constructed ``MobileViT`` model.
    """
    return MobileViT(
        image_size=(256, 256),
        dims=[144, 192, 240],
        channels=[16, 32, 64, 64, 128, 128, 256, 256, 320, 320, 1280],
        num_classes=num_classes,
        constant_factor=1,
    )

def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable



In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Training-time preprocessing and augmentation; the transforms are applied in
# the order listed. The final output is a normalized tensor of size 256x256.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics — confirm they suit this dataset.
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop((256, 256), scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
])
# Validation-time preprocessing: resize and normalize only, no random
# augmentation. Uses the same size and normalization constants as training.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Training images are read from the 'train_cat' folder (one sub-folder per
# class, as ImageFolder expects) and shuffled every epoch.
train_dataset = datasets.ImageFolder(root='train_cat', transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Validation images are read from the 'val_cat' folder and kept in a fixed order.
val_dataset = datasets.ImageFolder(root='val_cat', transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Classification loss passed to both the training and the validation loop.
criterion = nn.CrossEntropyLoss()

def save_checkpoint(state, filename='best_model.pth'):
    """Serialize ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: Object to store (the training loop passes a dict of epoch,
            model/optimizer state dicts and best accuracy).
        filename: Destination path; an existing file is overwritten.
    """
    torch.save(obj=state, f=filename)

def train(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of correctly classified samples, both measured on the
        outputs produced during the pass.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for batch_inputs, batch_labels in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()

        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per-sample.
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        predictions = logits.max(dim=1).indices
        num_seen += batch_labels.size(0)
        num_correct += (predictions == batch_labels).sum().item()

    return loss_sum / num_seen, num_correct / num_seen

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            # Weight the batch loss by batch size so the mean is per-sample.
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            predictions = logits.max(dim=1).indices
            num_seen += batch_labels.size(0)
            num_correct += (predictions == batch_labels).sum().item()

    return loss_sum / num_seen, num_correct / num_seen



cuda


In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt

# Where the best-scoring checkpoint (by validation accuracy) is written.
checkpoint_path = 'best_model_checkpoint.pth'

num_epochs = 100

model = mobilevit_xxs(num_classes=12)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
print(count_parameters(model))

# Per-epoch history, used for the plots after training.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
best_val_acc = 0.0
best_model_wts = None

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # No learning-rate scheduler is stepped in this loop, so this stays at the
    # optimizer's initial value.
    current_lr = optimizer.param_groups[0]['lr']

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}')
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')
    print(f'Current Learning Rate: {current_lr:.6f}')

    # Keep the weights of the epoch with the highest validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns tensors that share storage with the live
        # parameters, so they must be cloned; otherwise this "best" snapshot
        # would keep changing as training continues.
        best_model_wts = {
            key: value.detach().clone()
            for key, value in model.state_dict().items()
        }
        save_checkpoint({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'best_val_acc': best_val_acc,
            'optimizer_state_dict': optimizer.state_dict(),
        }, filename=checkpoint_path)

# Loss curves on the left, accuracy curves on the right.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train')
plt.plot(val_losses, label='Validation')
plt.title('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train')
plt.plot(val_accuracies, label='Validation')
plt.title('Accuracy')
plt.legend()

plt.show()

# Normal
# 830476
# Epoch 96/200
# Train Loss: 0.7575, Train Accuracy: 0.7109
# Validation Loss: 1.0073, Validation Accuracy: 0.6708
# Current Learning Rate: 0.001000
# Epoch 196/200
# Train Loss: 0.3117, Train Accuracy: 0.8938
# Validation Loss: 1.0361, Validation Accuracy: 0.7167 (5)
# Current Learning Rate: 0.001000

# Coordinate Attention On MV2Block
# 844506
# Epoch 99/100
# Train Loss: 0.6491, Train Accuracy: 0.7734
# Validation Loss: 1.0123, Validation Accuracy: 0.7000
# Current Learning Rate: 0.001000
# Epoch 200
# Epoch 176/200
# Train Loss: 0.3923, Train Accuracy: 0.8615
# Validation Loss: 1.0543, Validation Accuracy: 0.7208 (4)
# Current Learning Rate: 0.001000

# SE On MV2Block
# 836844
# Epoch 72/200
# Train Loss: 0.7780, Train Accuracy: 0.7146
# Validation Loss: 0.9999, Validation Accuracy: 0.6792
# Current Learning Rate: 0.001000
# Epoch 137/200
# Train Loss: 0.4203, Train Accuracy: 0.8536
# Validation Loss: 1.0707, Validation Accuracy: 0.7083 (6)
# Current Learning Rate: 0.001000

# Coordinate Attention On MV2Block + 5x5 Conv (PetVision)
# 852442
# Epoch 94/100
# Train Loss: 0.5788, Train Accuracy: 0.7937
# Validation Loss: 1.0999, Validation Accuracy: 0.6917
# Current Learning Rate: 0.001000
# Epoch 200
# Epoch 197/200
# Train Loss: 0.2681, Train Accuracy: 0.9016
# Validation Loss: 1.1273, Validation Accuracy: 0.7375  (3)
# Current Learning Rate: 0.001000

# Coordinate Attention On MV2Block + 5x5 Conv + 1.22x (PetVision + 1.22x)
# 1224678
# Epoch 94/200
# Train Loss: 0.6958, Train Accuracy: 0.7495
# Validation Loss: 0.9741, Validation Accuracy: 0.6937
# Current Learning Rate: 0.001000
# Epoch 198/200
# Train Loss: 0.3077, Train Accuracy: 0.8880
# Validation Loss: 0.9529, Validation Accuracy: 0.7521 (1)
# Current Learning Rate: 0.001000

# Normal + 1.24x
# 1235319
# Epoch 95/200
# Train Loss: 0.6388, Train Accuracy: 0.7677
# Validation Loss: 0.9213, Validation Accuracy: 0.6958
# Current Learning Rate: 0.001000
# Epoch 122/200
# Train Loss: 0.4876, Train Accuracy: 0.8141
# Validation Loss: 0.8511, Validation Accuracy: 0.7375 (2)
# Current Learning Rate: 0.001000

# MobileNetv2 (0.75 Width Multiplier)
# 1370796
# Epoch 98/100
# Train Loss: 0.6753, Train Accuracy: 0.7500
# Validation Loss: 1.1615, Validation Accuracy: 0.6458
# Current Learning Rate: 0.001000
# Epoch 87/100
# Train Loss: 0.3210, Train Accuracy: 0.8870
# Validation Loss: 1.0084, Validation Accuracy: 0.7208
# Current Learning Rate: 0.001000

# MobileNetv3 Small
# 1530156
# Epoch 89/100
# Train Loss: 0.6280, Train Accuracy: 0.7750
# Validation Loss: 1.2734, Validation Accuracy: 0.6354
# Current Learning Rate: 0.001000
# Epoch 77/100
# Train Loss: 0.2575, Train Accuracy: 0.9135
# Validation Loss: 1.5066, Validation Accuracy: 0.6562
# Current Learning Rate: 0.001000