In [1]:
!pip install torch torchvision tqdm pillow
!pip install adabelief-pytorch

Collecting adabelief-pytorch
  Downloading adabelief_pytorch-0.2.1-py3-none-any.whl.metadata (616 bytes)
Downloading adabelief_pytorch-0.2.1-py3-none-any.whl (5.8 kB)
Installing collected packages: adabelief-pytorch
Successfully installed adabelief-pytorch-0.2.1


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
import random
import math
from tqdm import tqdm
from adabelief_pytorch import AdaBelief

class ModifiedResNet18(nn.Module):
    """ResNet-18 feature extractor with a stride-1 stem and stride-1 first two stages.

    The stem convolution is a newly constructed 3x3 / stride-1 layer (it does
    not reuse the torchvision stem weights), the max-pool is replaced by an
    identity, and every convolution stride inside ``layer1`` and ``layer2`` is
    forced to 1. ``layer3`` and ``layer4`` are taken from torchvision unchanged.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet18``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        base = models.resnet18(pretrained=pretrained)

        # Stem: fresh 3x3 convolution, batch-norm and activation borrowed from
        # the torchvision network, and a no-op where the max-pool used to be.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = base.bn1
        self.relu = base.relu
        self.maxpool = nn.Identity()

        # Residual stages: the first two are patched to stride 1, the last two
        # are reused as they come.
        self.layer1 = self._modify_layer(base.layer1, stride=1)
        self.layer2 = self._modify_layer(base.layer2, stride=1)
        self.layer3 = base.layer3
        self.layer4 = base.layer4

        self.is_pretraining = True

    def _modify_layer(self, layer, stride):
        """Rewrite the convolution strides of every block in ``layer`` in place.

        Each block's first convolution (and the convolution of its downsample
        branch, when it has one) gets ``stride``; the second convolution is
        always set to stride 1. The same ``layer`` object is returned.
        """
        requested = (stride, stride)
        for residual_block in layer:
            residual_block.conv1.stride = requested
            residual_block.conv2.stride = (1, 1)
            shortcut = residual_block.downsample
            if shortcut is not None:
                shortcut[0].stride = requested
        return layer

    def freeze_parameters(self):
        """Leave pretraining mode and stop gradient tracking for all parameters."""
        self.is_pretraining = False
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Run the stem and the four residual stages; return the final feature map."""
        pipeline = (
            self.conv1,  # 3x3 stem convolution
            self.bn1,
            self.relu,
            self.maxpool,  # identity
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
        )
        for stage in pipeline:
            x = stage(x)
        return x

class PretrainingModel(nn.Module):
    """Classification wrapper used to pretrain the backbone.

    Feeds the ``ModifiedResNet18`` feature map through global average pooling
    and a single linear layer that maps 512 features to ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = ModifiedResNet18(pretrained=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feature_map = self.backbone(x)
        # Pool each channel to a single value, then drop the 1x1 spatial dims.
        pooled = self.avgpool(feature_map).flatten(1)
        return self.fc(pooled)

class PretrainingTrainer:
    """Supervised training loop for the backbone pretraining classifier.

    Optimises ``model`` with AdaBelief (lr 1e-4) under cross-entropy, lowers
    the learning rate by 10x when validation accuracy has not improved for
    more than 5 epochs, and writes ``model.backbone``'s weights to
    ``best_backbone.pt`` each time validation accuracy reaches a new maximum.
    """

    def __init__(self, model, train_loader, val_loader, device):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdaBelief(self.model.parameters(), lr=1e-4)
        # mode='max' because the scheduler is stepped with validation accuracy.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.1, patience=5
        )

    @staticmethod
    def _count_correct(logits, labels):
        """Number of rows of ``logits`` whose highest-scoring column equals the label."""
        top_index = logits.max(1)[1]
        return top_index.eq(labels).sum().item()

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            (mean loss per batch, accuracy in percent).
        """
        self.model.train()
        running_loss = 0
        n_correct = 0
        n_seen = 0

        for images, labels in tqdm(self.train_loader):
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(images)
            batch_loss = self.criterion(logits, labels)

            batch_loss.backward()
            self.optimizer.step()

            running_loss += batch_loss.item()
            n_seen += labels.size(0)
            n_correct += self._count_correct(logits, labels)

        mean_loss = running_loss / len(self.train_loader)
        return mean_loss, 100. * n_correct / n_seen

    def validate(self):
        """Return accuracy in percent over ``val_loader`` with gradients disabled."""
        self.model.eval()
        n_correct = 0
        n_seen = 0

        with torch.no_grad():
            for images, labels in self.val_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                logits = self.model(images)

                n_seen += labels.size(0)
                n_correct += self._count_correct(logits, labels)

        return 100. * n_correct / n_seen

    def train(self, epochs):
        """Train for ``epochs`` epochs and return the best validation accuracy seen."""
        best_acc = 0

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch()
            val_acc = self.validate()

            print(f'Epoch: {epoch}')
            print(f'Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.3f}%')
            print(f'Val Acc: {val_acc:.3f}%')

            self.scheduler.step(val_acc)

            improved = val_acc > best_acc
            if improved:
                best_acc = val_acc
                # Only the backbone weights are stored, not the classifier head.
                torch.save(self.model.backbone.state_dict(), 'best_backbone.pt')

        return best_acc

class GRUWithSkipConnection(nn.Module):
    """GRU cell whose output is summed with a linear projection of its input.

    Args:
        input_dim: size of the input vectors.
        hidden_dim: size of the hidden state and of the returned vectors.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.skip_proj = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, h):
        """Return ``GRUCell(x, h) + Linear(x)``."""
        residual = self.skip_proj(x)
        updated = self.gru(x, h)
        return updated + residual

class PatternExtractor(nn.Module):
    """Iteratively extracts ``num_patterns`` pattern vectors from a feature map.

    The feature map is projected to ``hidden_dim`` channels with a 1x1
    convolution, a learned embedding is added, and the spatial positions are
    flattened into a sequence. Patterns are initialised from the mean feature
    and then refined ``num_iterations`` times: each pattern attends over all
    positions (softmax over positions, scaled by a learned sigmoid gate), and
    the attended context is fed to a GRU cell with a skip connection.

    Notes:
        * ``positional_embedding`` has shape ``(1, hidden_dim, 1, 1)``, so the
          same vector is broadcast to every spatial location.
        * The recurrent state starts at zero, so the initial patterns only
          influence the queries of the first iteration.

    Args:
        in_channels: channels of the incoming feature map.
        hidden_dim: width of the pattern vectors and all internal layers.
        num_patterns: number of pattern vectors produced per image.
        num_iterations: number of refinement steps; must be at least 1.

    Raises:
        ValueError: if ``num_iterations`` is smaller than 1 (the forward pass
            would have no attention map to return).
    """

    def __init__(self, in_channels=512, hidden_dim=256, num_patterns=7, num_iterations=3):
        super(PatternExtractor, self).__init__()
        if num_iterations < 1:
            raise ValueError(
                f"num_iterations must be >= 1, got {num_iterations}"
            )
        self.hidden_dim = hidden_dim
        self.num_patterns = num_patterns
        self.num_iterations = num_iterations

        # Initial feature processing with 1x1 conv
        self.conv1x1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        self.positional_embedding = nn.Parameter(torch.randn(1, hidden_dim, 1, 1))

        # Attention gating parameter (passed through a sigmoid in forward)
        self.attention_gate = nn.Parameter(torch.ones(1))

        # GRU with skip connections
        self.grusc = GRUWithSkipConnection(hidden_dim, hidden_dim)

        # Maps the mean feature vector to the initial set of patterns
        self.pattern_init = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_patterns * hidden_dim)
        )

        # Networks gq (queries from patterns) and gM (keys from features):
        # 3 FC layers with ReLU in between
        self.gq = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        self.gM = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def compute_attention(self, q, k):
        """Scaled dot-product attention weights, softmax-normalised over the last axis."""
        attn = torch.matmul(q, k.transpose(-2, -1))
        attn = attn / math.sqrt(self.hidden_dim)
        return F.softmax(attn, dim=-1)

    def forward(self, x):
        """Extract patterns and their attention maps.

        Args:
            x: feature map of shape ``[B, in_channels, H, W]``.

        Returns:
            ``(patterns, attn)`` where ``patterns`` is
            ``[B, num_patterns, hidden_dim]`` and ``attn`` is the gated
            attention of the final iteration reshaped to
            ``[B, num_patterns, H, W]``.
        """
        batch_size = x.size(0)

        # Initial feature processing
        x = self.conv1x1(x)  # [B, hidden_dim, H, W]
        x = x + self.positional_embedding
        h, w = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, hidden_dim]

        # Initialize patterns from the spatially averaged feature
        patterns = self.pattern_init(x.mean(1))
        patterns = patterns.view(batch_size, self.num_patterns, self.hidden_dim)

        # The keys and the gate depend only on the features / parameters, not
        # on the evolving patterns, so they are computed once outside the loop.
        keys = self.gM(x)  # [B, H*W, hidden_dim]
        gate = torch.sigmoid(self.attention_gate)

        # Allocate the recurrent state directly with x's device and dtype.
        h_state = x.new_zeros(batch_size * self.num_patterns, self.hidden_dim)

        for _ in range(self.num_iterations):
            queries = self.gq(patterns)  # [B, num_patterns, hidden_dim]

            # Attention over spatial positions, then gating
            attn = self.compute_attention(queries, keys) * gate

            # Update patterns using the attended context and the GRU
            context = torch.bmm(attn, x)  # [B, num_patterns, hidden_dim]
            h_state = self.grusc(context.reshape(-1, self.hidden_dim), h_state)
            patterns = h_state.view(batch_size, self.num_patterns, self.hidden_dim)

        return patterns, attn.view(batch_size, self.num_patterns, h, w)


class PairwiseMatchingModule(nn.Module):
    """Scores every (query, support) pair with a small MLP.

    Each pair is represented by the concatenation of the query vector and the
    support vector (``2 * hidden_dim`` values) and mapped to a single score.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.matching_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, query_features, support_features):
        """Return a ``[num_query, num_support]`` matrix of pair scores.

        Inputs with more than two dimensions are first averaged over axis 1.
        """
        queries = query_features.mean(1) if query_features.dim() > 2 else query_features
        supports = support_features.mean(1) if support_features.dim() > 2 else support_features

        num_query = queries.size(0)
        num_support = supports.size(0)

        # Tile both sets so that entry [i, j] holds (query i, support j).
        tiled_queries = queries.unsqueeze(1).expand(-1, num_support, -1)
        tiled_supports = supports.unsqueeze(0).expand(num_query, -1, -1)
        pairs = torch.cat([tiled_queries, tiled_supports], dim=-1)

        flat_pairs = pairs.view(-1, pairs.size(-1))
        return self.matching_net(flat_pairs).view(num_query, -1)

class MTUNetPlusPlus(nn.Module):
    """Few-shot matcher: backbone -> pattern extractor -> pairwise matching.

    Args:
        hidden_dim: width of the pattern vectors and of the matching MLP.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        self.backbone = ModifiedResNet18()
        self.pattern_extractor = PatternExtractor(in_channels=512, hidden_dim=hidden_dim)
        self.matching_module = PairwiseMatchingModule(hidden_dim)

    def _encode(self, images):
        """Backbone features followed by pattern extraction."""
        return self.pattern_extractor(self.backbone(images))

    def forward(self, query_img, support_imgs=None, return_features=False):
        """Score queries against supports, or return the query patterns.

        Returns:
            ``(query_patterns, query_attn)`` when ``support_imgs`` is ``None``
            or ``return_features`` is truthy; otherwise the matching module's
            scores computed from the pattern-averaged query and support
            vectors.
        """
        query_patterns, query_attn = self._encode(query_img)

        if support_imgs is None or return_features:
            return query_patterns, query_attn

        support_patterns, _support_attn = self._encode(support_imgs)

        # Collapse the pattern axis so each image is a single vector.
        query_vectors = query_patterns.mean(1)
        support_vectors = support_patterns.mean(1)
        return self.matching_module(query_vectors, support_vectors)


class ModelTrainer:
    """Episodic (few-shot) training loop.

    Each item yielded by the loaders is one episode:
    ``(support_imgs, support_labels, query_imgs, query_labels)``, every tensor
    carrying a leading batch axis of size 1 that is squeezed away here.

    ``model(query_imgs, support_imgs)`` is expected to return a
    ``[num_query, num_support_images]`` score matrix. Those per-image scores
    are pooled into one score per episode class (mean over that class's
    support images, using ``support_labels``) before the cross-entropy loss and
    the accuracy are computed, so predictions and ``query_labels`` live in the
    same ``[0, n_way)`` label space.

    Args:
        model: network exposing ``backbone``, ``pattern_extractor`` and
            ``matching_module`` sub-modules (required by ``train_full``).
        train_loader: iterable of training episodes.
        val_loader: iterable of validation episodes.
        optimizer: optimizer used by ``train_epoch``; ``train_full`` replaces
            it with its own AdaBelief instances.
        device: device the episode tensors are moved to.
        episodes_per_epoch: stored for reference; the loaders decide how many
            episodes are actually run.
        checkpoint_dir: directory that receives the best checkpoints.
    """

    def __init__(self, model, train_loader, val_loader, optimizer, device,
                 episodes_per_epoch=500, checkpoint_dir='checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.episodes_per_epoch = episodes_per_epoch
        self.criterion = nn.CrossEntropyLoss()

        self.best_accuracy = 0.0
        self.best_epoch = 0

    @staticmethod
    def _class_scores(scores, support_labels):
        """Pool per-support-image scores into per-class scores.

        Args:
            scores: ``[num_query, num_support]`` pair scores.
            support_labels: ``[num_support]`` integer class index of each
                support image (values in ``[0, n_way)``).

        Returns:
            ``[num_query, n_way]`` tensor holding, for every query, the mean
            score over the support images of each class.
        """
        num_classes = int(support_labels.max().item()) + 1
        membership = F.one_hot(support_labels, num_classes).to(scores.dtype)  # [S, C]
        # clamp guards against a class index that has no support image
        images_per_class = membership.sum(dim=0).clamp(min=1)
        return scores @ membership / images_per_class

    def _run_episode(self, episode):
        """Forward one episode; return ``(loss, num_correct, num_queries)``."""
        support_imgs, support_labels, query_imgs, query_labels = (
            tensor.squeeze(0).to(self.device) for tensor in episode
        )

        pair_scores = self.model(query_imgs, support_imgs)
        class_scores = self._class_scores(pair_scores, support_labels)

        loss = self.criterion(class_scores, query_labels)
        correct = class_scores.argmax(dim=1).eq(query_labels).sum().item()
        return loss, correct, query_labels.size(0)

    def train_epoch(self, epoch):
        """Optimise over every training episode once.

        Returns:
            ``(mean loss per episode, accuracy in percent)``.
        """
        self.model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        num_episodes = 0

        for batch_idx, episode in enumerate(self.train_loader):
            self.optimizer.zero_grad()

            loss, correct, num_queries = self._run_episode(episode)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            total_correct += correct
            total_samples += num_queries
            num_episodes += 1

            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}, '
                      f'Accuracy: {100.0 * correct / num_queries:.2f}%')

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        avg_loss = total_loss / max(num_episodes, 1)
        avg_acc = 100.0 * total_correct / max(total_samples, 1)

        return avg_loss, avg_acc

    def validate(self):
        """Evaluate on every validation episode without gradient tracking.

        Returns:
            ``(mean loss per episode, accuracy in percent)``.
        """
        self.model.eval()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        num_episodes = 0

        with torch.no_grad():
            for episode in self.val_loader:
                loss, correct, num_queries = self._run_episode(episode)

                total_loss += loss.item()
                total_correct += correct
                total_samples += num_queries
                num_episodes += 1

        avg_loss = total_loss / max(num_episodes, 1)
        avg_acc = 100.0 * total_correct / max(total_samples, 1)

        return avg_loss, avg_acc

    def save_checkpoint(self, epoch, val_acc):
        """Save model checkpoint if validation accuracy improves"""
        if val_acc <= self.best_accuracy:
            return

        self.best_accuracy = val_acc
        self.best_epoch = epoch

        os.makedirs(self.checkpoint_dir, exist_ok=True)

        checkpoint_path = os.path.join(self.checkpoint_dir, f'best_model_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_accuracy': val_acc,
        }, checkpoint_path)
        print(f'Saved checkpoint: {checkpoint_path}')

    def train_phase(self, num_epochs, phase_name="Training"):
        """Run ``num_epochs`` train/validate cycles for one named phase.

        For the phases named "Initial Training" and "Fine-tuning" every
        parameter group's learning rate is multiplied by 0.1 once, at epoch
        index 40 and 10 respectively (0-based).
        """
        lr_drop_epoch = {"Initial Training": 40, "Fine-tuning": 10}.get(phase_name)

        for epoch in range(num_epochs):
            print(f"\n{phase_name} - Epoch {epoch+1}/{num_epochs}")

            # Learning rate schedules as per specs
            if epoch == lr_drop_epoch:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= 0.1

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()

            print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
            print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

            self.save_checkpoint(epoch, val_acc)

    def _build_optimizer(self, backbone_lr, extractor_lr, matcher_lr):
        """Create an AdaBelief optimizer with one parameter group per component."""
        param_groups = [
            {'params': self.model.backbone.parameters(), 'lr': backbone_lr},
            {'params': self.model.pattern_extractor.parameters(), 'lr': extractor_lr},
            {'params': self.model.matching_module.parameters(), 'lr': matcher_lr}
        ]
        return AdaBelief(param_groups)

    def train_full(self):
        """Complete training process according to specifications"""
        # Phase 1: Initial Training (150 epochs), all components at lr=1e-4
        print("Starting Phase 1: Initial Training")
        self.optimizer = self._build_optimizer(1e-4, 1e-4, 1e-4)
        self.train_phase(num_epochs=150, phase_name="Initial Training")

        # Phase 2: Fine-tuning (20 epochs), lower lr for the backbone and the
        # pattern extractor, original lr for the matching module
        print("\nStarting Phase 2: Fine-tuning")
        self.optimizer = self._build_optimizer(1e-5, 1e-5, 1e-4)
        self.train_phase(num_epochs=20, phase_name="Fine-tuning")

        print(f"\nTraining completed! Best accuracy: {self.best_accuracy:.4f} at epoch {self.best_epoch}")

class EpisodeDataset(Dataset):
    """Dataset that samples a random N-way few-shot episode per item.

    ``root_dir`` must contain one sub-directory per class name in
    ``allowed_classes``; image files are those whose extension is ``.jpg``,
    ``.jpeg`` or ``.png`` (matched case-insensitively). File lists are sorted
    so that, for a given ``random`` seed, the sampled episodes do not depend on
    the directory listing order.

    Args:
        root_dir: directory holding the per-class image folders.
        allowed_classes: class (folder) names this dataset may sample from.
        transform: optional callable applied to each RGB ``PIL.Image``.
        n_way: number of classes per episode.
        n_support: support images per class.
        n_query: query images per class.
        episodes_per_epoch: value reported by ``len()``.

    Raises:
        ValueError: if fewer than ``n_way`` classes are given, or a class has
            fewer than ``n_support + n_query`` images (otherwise the failure
            would only surface as a ``random.sample`` error during training).
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, allowed_classes, transform=None, n_way=2, n_support=5, n_query=15,
                 episodes_per_epoch=500):
        self.root_dir = root_dir
        self.transform = transform
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        self.episodes_per_epoch = episodes_per_epoch

        # random.sample needs a sequence; copying also protects against the
        # caller mutating its own list later.
        self.classes = list(allowed_classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        if len(self.classes) < n_way:
            raise ValueError(
                f"n_way={n_way} but only {len(self.classes)} classes were provided"
            )

        images_needed = n_support + n_query
        self.images_by_class = {}
        for cls in self.classes:
            class_path = os.path.join(root_dir, cls)
            image_paths = sorted(
                os.path.join(class_path, name)
                for name in os.listdir(class_path)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            if len(image_paths) < images_needed:
                raise ValueError(
                    f"class '{cls}' has {len(image_paths)} images in {class_path}, "
                    f"but n_support + n_query = {images_needed} are required"
                )
            self.images_by_class[cls] = image_paths

    def __len__(self):
        return self.episodes_per_epoch  # Configurable episodes per epoch

    def _load_image(self, img_path):
        """Read one image as RGB, close the file, and apply the transform."""
        # convert() returns a fully loaded copy, so the file handle can be
        # released as soon as the with-block ends.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Sample a fresh episode; ``idx`` is ignored.

        Returns:
            ``(support_images, support_labels, query_images, query_labels)``.
            Labels are episode-local indices ``0 .. n_way - 1`` assigned in the
            order the classes were drawn; images are grouped class by class.
        """
        episode_classes = random.sample(self.classes, self.n_way)

        support_images = []
        support_labels = []
        query_images = []
        query_labels = []

        for label, cls in enumerate(episode_classes):
            selected_images = random.sample(
                self.images_by_class[cls], self.n_support + self.n_query
            )

            for img_path in selected_images[:self.n_support]:
                support_images.append(self._load_image(img_path))
                support_labels.append(label)

            for img_path in selected_images[self.n_support:]:
                query_images.append(self._load_image(img_path))
                query_labels.append(label)

        support_images = torch.stack(support_images)
        support_labels = torch.tensor(support_labels)
        query_images = torch.stack(query_images)
        query_labels = torch.tensor(query_labels)

        return support_images, support_labels, query_images, query_labels

def split_classes(root_dir, val_split=0.2, random_seed=42):
    """Split the class folders under ``root_dir`` into train and validation sets.

    The sub-directory names are sorted before shuffling so the split depends
    only on ``random_seed`` and the set of class names, not on the order in
    which the filesystem happens to list them. Shuffling uses a private
    ``random.Random`` instance, leaving the global ``random`` state untouched.

    Args:
        root_dir: directory whose immediate sub-directories are the classes.
        val_split: fraction of classes assigned to validation, in ``[0, 1]``.
        random_seed: seed for the shuffle.

    Returns:
        ``(train_classes, val_classes)`` as two lists of directory names; the
        first ``int(n * (1 - val_split))`` shuffled names go to training.

    Raises:
        ValueError: if ``val_split`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split}")

    classes = sorted(
        d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
    )

    random.Random(random_seed).shuffle(classes)

    split_idx = int(len(classes) * (1 - val_split))
    train_classes = classes[:split_idx]
    val_classes = classes[split_idx:]

    return train_classes, val_classes

def main(data_path='/kaggle/input/ham10000-and-gan/synthetic_images', episodes_per_epoch=500):
    """Build the datasets, model and trainer, then run the full training schedule.

    Args:
        data_path: directory containing one sub-folder of images per class.
        episodes_per_epoch: number of episodes each dataset reports per epoch.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Shared preprocessing constants
    image_size = (80, 80)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training transform: resize + random augmentation
    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        normalize
    ])

    # Validation transform: same preprocessing without the random
    # augmentation, so validation accuracy is not perturbed by random flips
    # and affine jitter.
    val_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Split classes
    train_classes, val_classes = split_classes(data_path)

    print(f"Number of training classes: {len(train_classes)}")
    print(f"Number of validation classes: {len(val_classes)}")

    # Create datasets
    train_dataset = EpisodeDataset(
        root_dir=data_path,
        allowed_classes=train_classes,
        transform=train_transform,
        n_way=2,
        n_support=5,
        n_query=15,
        episodes_per_epoch=episodes_per_epoch
    )

    val_dataset = EpisodeDataset(
        root_dir=data_path,
        allowed_classes=val_classes,
        transform=val_transform,
        n_way=2,
        n_support=5,
        n_query=15,
        episodes_per_epoch=episodes_per_epoch
    )

    # Create dataloaders (each batch is exactly one episode)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Create model
    model = MTUNetPlusPlus(hidden_dim=256).to(device)

    # Define parameter groups with different learning rates
    param_groups = [
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.pattern_extractor.parameters(), 'lr': 1e-4},
        {'params': model.matching_module.parameters(), 'lr': 1e-4}
    ]

    # Initialize optimizer with parameter groups. ModelTrainer requires one at
    # construction; train_full() builds its own optimizer for each phase.
    optimizer = AdaBelief(param_groups)

    # Create trainer
    trainer = ModelTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        episodes_per_epoch=episodes_per_epoch
    )

    # The trainer will automatically adjust learning rates during training phases
    trainer.train_full()

# Entry point: start the full training pipeline when this code is executed
# directly (as a script or notebook cell) rather than imported as a module.
if __name__ == '__main__':
    main()

Using device: cuda
Number of training classes: 5
Number of validation classes: 2


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 196MB/s]


[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief
Star