In [5]:
!pip install torch torchvision tqdm pillow
!pip install adabelief-pytorch



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
import random
from tqdm import tqdm
from adabelief_pytorch import AdaBelief

class ModifiedResNet18(nn.Module):
    def __init__(self):
        super(ModifiedResNet18, self).__init__()
        resnet = models.resnet18(pretrained=True)
        
        # Modified first conv with stride 2
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        
        # Modify layer1 and layer2 to remove downsampling
        self.layer1 = self._modify_layer(resnet.layer1, stride=1)
        self.layer2 = self._modify_layer(resnet.layer2, stride=1)
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        
        # Freeze parameters
        for param in self.parameters():
            param.requires_grad = False
    
    def _modify_layer(self, layer, stride):
        for block in layer:
            block.conv1.stride = (stride, stride)
            block.conv2.stride = (1, 1)
            if block.downsample is not None:
                block.downsample[0].stride = (stride, stride)
        return layer
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        return x  # Output: [B, 512, H/16, W/16]

class GRUWithSkipConnection(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(GRUWithSkipConnection, self).__init__()
        self.gru = nn.GRUCell(input_dim, hidden_dim)
        self.skip_proj = nn.Linear(input_dim, hidden_dim)
        
    def forward(self, x, h):
        h_new = self.gru(x, h)
        skip = self.skip_proj(x)
        return h_new + skip

class PatternExtractor(nn.Module):
    def __init__(self, in_channels=512, hidden_dim=256, num_patterns=7, num_iterations=3):
        super(PatternExtractor, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_patterns = num_patterns
        self.num_iterations = num_iterations
        
        # Initial feature processing
        self.conv1x1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        self.positional_embedding = nn.Parameter(torch.randn(1, hidden_dim, 1, 1))
        
        # Attention gating parameter
        self.attention_gate = nn.Parameter(torch.ones(1))
        
        # GRU with skip connections
        self.grusc = GRUWithSkipConnection(hidden_dim, hidden_dim)
        
        # Pattern networks
        self.pattern_init = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_patterns * hidden_dim)
        )
        
        # Attention networks
        self.query_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.key_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
    def forward(self, x):
        batch_size = x.size(0)
        
        # Initial feature processing
        x = self.conv1x1(x)  # [B, hidden_dim, H, W]
        x = x + self.positional_embedding
        h, w = x.shape[-2:]
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, hidden_dim]
        
        # Initialize patterns
        patterns = self.pattern_init(x.mean(1))  # [B, num_patterns * hidden_dim]
        patterns = patterns.view(batch_size, self.num_patterns, self.hidden_dim)
        
        # Iterative pattern refinement
        h_state = torch.zeros(batch_size * self.num_patterns, self.hidden_dim).to(x.device)
        
        for _ in range(self.num_iterations):
            # Self-attention
            q = self.query_net(patterns.reshape(-1, self.hidden_dim))
            k = self.key_net(x.reshape(-1, self.hidden_dim))
            
            # Compute attention scores
            attn = torch.matmul(q.view(batch_size, self.num_patterns, -1),
                              k.view(batch_size, -1, self.hidden_dim).transpose(1, 2))
            attn = F.softmax(attn / (self.hidden_dim ** 0.5), dim=-1)
            
            # Apply attention gating
            attn = attn * torch.sigmoid(self.attention_gate)
            
            # Update patterns
            context = torch.bmm(attn, x)  # [B, num_patterns, hidden_dim]
            
            # GRU update
            context_flat = context.reshape(-1, self.hidden_dim)
            h_state = self.grusc(context_flat, h_state)
            patterns = h_state.view(batch_size, self.num_patterns, self.hidden_dim)
        
        return patterns, attn.view(batch_size, self.num_patterns, h, w)

class PairwiseMatchingModule(nn.Module):
    def __init__(self, hidden_dim):
        super(PairwiseMatchingModule, self).__init__()
        self.matching_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, query_features, support_features):
        # Ensure both inputs are 2D
        if query_features.dim() > 2:
            query_features = query_features.mean(1)  # Average across any extra dimensions
        if support_features.dim() > 2:
            support_features = support_features.mean(1)  # Average across any extra dimensions
            
        batch_size = query_features.size(0)
        
        # Reshape for pairwise comparison
        query_expanded = query_features.unsqueeze(1)  # [B, 1, H]
        
        # Combine features
        combined = torch.cat([
            query_expanded.expand(-1, support_features.size(0), -1),  # [B, S, H]
            support_features.unsqueeze(0).expand(batch_size, -1, -1)  # [B, S, H]
        ], dim=-1)
        
        # Get scores
        scores = self.matching_net(combined.view(-1, combined.size(-1)))
        return scores.view(batch_size, -1)

class MTUNetPlusPlus(nn.Module):
    def __init__(self, hidden_dim=256):
        super(MTUNetPlusPlus, self).__init__()
        self.backbone = ModifiedResNet18()
        self.pattern_extractor = PatternExtractor(in_channels=512, hidden_dim=hidden_dim)
        self.matching_module = PairwiseMatchingModule(hidden_dim)
        
    def forward(self, query_img, support_imgs=None, return_features=False):
        query_features = self.backbone(query_img)
        query_patterns, query_attn = self.pattern_extractor(query_features)
        
        if support_imgs is None or return_features:
            return query_patterns, query_attn
        
        support_features = self.backbone(support_imgs)
        support_patterns, _ = self.pattern_extractor(support_features)
        
        # Ensure patterns are properly averaged before matching
        query_features = query_patterns.mean(1)   # Average across patterns
        support_features = support_patterns.mean(1)  # Average across patterns
        
        scores = self.matching_module(query_features, support_features)
        
        return scores

class ModelTrainer:
    def __init__(self, model, train_loader, val_loader, optimizer, device,
                 checkpoint_dir='checkpoints'):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        
        os.makedirs(checkpoint_dir, exist_ok=True)
        
        self.best_accuracy = 0.0
        self.best_epoch = 0
    
    # Modify train_episode method
    def train_episode(self, support_images, support_labels, query_images, query_labels):
        self.model.train()
        
        # Reshape tensors to correct dimensions
        support_images = support_images.squeeze(0)  # Remove extra batch dimension
        query_images = query_images.squeeze(0)
        support_labels = support_labels.squeeze(0)
        query_labels = query_labels.squeeze(0)
        
        # Move to device
        support_images = support_images.to(self.device)
        support_labels = support_labels.to(self.device)
        query_images = query_images.to(self.device)
        query_labels = query_labels.to(self.device)
        
        self.optimizer.zero_grad()
        scores = self.model(query_images, support_images)
        loss = F.cross_entropy(scores, query_labels)
        
        loss.backward()
        self.optimizer.step()
        
        predictions = scores.max(1)[1]
        accuracy = (predictions == query_labels).float().mean()
        
        return loss.item(), accuracy.item()
        
    
    def validate_episode(self, support_images, support_labels, query_images, query_labels):
        self.model.eval()
        
        with torch.no_grad():
            support_images = support_images.to(self.device)
            support_labels = support_labels.to(self.device)
            query_images = query_images.to(self.device)
            query_labels = query_labels.to(self.device)
            
            scores = self.model(query_images, support_images)
            loss = F.cross_entropy(scores, query_labels)
            
            predictions = scores.max(1)[1]
            accuracy = (predictions == query_labels).float().mean()
        
        return loss.item(), accuracy.item()
    
    def train_epoch(self, epoch):
        total_loss = 0
        total_accuracy = 0
        
        pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
        for batch_idx, (support_imgs, support_labels, query_imgs, query_labels) in pbar:
            loss, accuracy = self.train_episode(support_imgs, support_labels, query_imgs, query_labels)
            
            total_loss += loss
            total_accuracy += accuracy
            
            pbar.set_description(f'Epoch {epoch} | Loss: {loss:.4f} | Acc: {accuracy:.4f}')
        
        return total_loss / len(self.train_loader), total_accuracy / len(self.train_loader)
    
    def validate(self):
        total_loss = 0
        total_accuracy = 0
        
        for support_imgs, support_labels, query_imgs, query_labels in tqdm(self.val_loader):
            loss, accuracy = self.validate_episode(support_imgs, support_labels, query_imgs, query_labels)
            total_loss += loss
            total_accuracy += accuracy
        
        return total_loss / len(self.val_loader), total_accuracy / len(self.val_loader)
    
    def save_checkpoint(self, epoch, accuracy):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'accuracy': accuracy
        }
        
        path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save(checkpoint, path)
        
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.best_epoch = epoch
            best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
            torch.save(checkpoint, best_path)
    
    def _adjust_learning_rates(self):
        for param_group in self.optimizer.param_groups:
            if any(p in self.model.backbone.parameters() for p in param_group['params']):
                param_group['lr'] = 1e-5
            else:
                param_group['lr'] = 1e-4
    
    def train_phase(self, num_epochs, phase_name="Training"):
        for epoch in range(num_epochs):
            print(f"\n{phase_name} - Epoch {epoch+1}/{num_epochs}")
            
            train_loss, train_acc = self.train_epoch(epoch)
            print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
            
            val_loss, val_acc = self.validate()
            print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")
            
            self.save_checkpoint(epoch, val_acc)
            
            if phase_name == "Initial Training" and epoch == 40:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] *= 0.1
       
    def train_full(self):
        """Complete two-phase training process"""
        # Phase 1: Initial training
        print("Starting Phase 1: Initial Training")
        self.train_phase(num_epochs=150, phase_name="Initial Training")
        
        # Phase 2: Fine-tuning
        print("\nStarting Phase 2: Fine-tuning")
        self._adjust_learning_rates()
        self.train_phase(num_epochs=20, phase_name="Fine-tuning")
        
        print(f"\nTraining completed! Best accuracy: {self.best_accuracy:.4f} at epoch {self.best_epoch}")

def split_classes(root_dir, val_split=0.2, random_seed=42):
    """
    Split classes into training and validation sets.
    
    Args:
        root_dir: Path to the data directory
        val_split: Proportion of classes to use for validation
        random_seed: Random seed for reproducibility
    
    Returns:
        train_classes: List of class names for training
        val_classes: List of class names for validation
    """
    random.seed(random_seed)
    
    # Get all class folders
    classes = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
    
    # Randomly shuffle classes
    random.shuffle(classes)
    
    # Split classes
    split_idx = int(len(classes) * (1 - val_split))
    train_classes = classes[:split_idx]
    val_classes = classes[split_idx:]
    
    return train_classes, val_classes

class EpisodeDataset(Dataset):
    def __init__(self, root_dir, allowed_classes, transform=None, n_way=2, n_support=5, n_query=15):
        """
        Dataset class for few-shot learning episodes
        
        Args:
            root_dir: Root directory containing class folders
            allowed_classes: List of class names this dataset can use
            transform: Image transformations
            n_way: Number of classes per episode
            n_support: Number of support examples per class
            n_query: Number of query examples per class
        """
        self.root_dir = root_dir
        self.transform = transform
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query
        
        # Only use the allowed classes
        self.classes = allowed_classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        
        # Build image paths dictionary
        self.images_by_class = {}
        for cls in self.classes:
            class_path = os.path.join(root_dir, cls)
            self.images_by_class[cls] = [
                os.path.join(class_path, img) 
                for img in os.listdir(class_path) 
                if img.endswith(('.jpg', '.jpeg', '.png'))
            ]
    
    def __len__(self):
        return 1000  # Number of episodes per epoch
    
    def __getitem__(self, idx):
        # Rest of the __getitem__ method remains the same
        episode_classes = random.sample(self.classes, self.n_way)
        
        support_images = []
        support_labels = []
        query_images = []
        query_labels = []
        
        for label, cls in enumerate(episode_classes):
            class_images = self.images_by_class[cls]
            selected_images = random.sample(class_images, self.n_support + self.n_query)
            
            for img_path in selected_images[:self.n_support]:
                image = Image.open(img_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                support_images.append(image)
                support_labels.append(label)
            
            for img_path in selected_images[self.n_support:]:
                image = Image.open(img_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                query_images.append(image)
                query_labels.append(label)
        
        support_images = torch.stack(support_images)
        support_labels = torch.tensor(support_labels)
        query_images = torch.stack(query_images)
        query_labels = torch.tensor(query_labels)
        
        return support_images, support_labels, query_images, query_labels

def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((80, 80)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Split classes into train and validation sets
    data_path = '/kaggle/input/ham10000-and-gan/synthetic_images'  # Update this to your actual data path
    train_classes, val_classes = split_classes(data_path)
    
    print(f"Number of training classes: {len(train_classes)}")
    print(f"Number of validation classes: {len(val_classes)}")
    
    # Create datasets with respective class splits
    train_dataset = EpisodeDataset(
        root_dir=data_path,
        allowed_classes=train_classes,
        transform=transform,
        n_way=2,
        n_support=5,
        n_query=15
    )
    
    val_dataset = EpisodeDataset(
        root_dir=data_path,
        allowed_classes=val_classes,
        transform=transform,
        n_way=2,
        n_support=5,
        n_query=15
    )
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
    # Rest of the main function remains the same
    model = MTUNetPlusPlus(hidden_dim=256).to(device)
    optimizer = AdaBelief(model.parameters(), lr=1e-4)
    
    trainer = ModelTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device
    )
    
    trainer.train_full()

if __name__ == '__main__':
    main()

Using device: cuda
Number of training classes: 5
Number of validation classes: 2
[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[

Epoch 0 | Loss: 2.2855 | Acc: 0.5000:  10%|█         | 100/1000 [00:21<02:36,  5.76it/s]