In [1]:
!pip install torch torchvision tqdm pillow
!pip install adabelief-pytorch

Collecting adabelief-pytorch
  Downloading adabelief_pytorch-0.2.1-py3-none-any.whl.metadata (616 bytes)
Downloading adabelief_pytorch-0.2.1-py3-none-any.whl (5.8 kB)
Installing collected packages: adabelief-pytorch
Successfully installed adabelief-pytorch-0.2.1


In [2]:
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from PIL import Image
from tqdm import tqdm
from torch.optim import AdamW
from adabelief_pytorch import AdaBelief
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Modified backbone to ensure consistent output size
class ModifiedResNet18(nn.Module):
    def __init__(self):
        super(ModifiedResNet18, self).__init__()
        # Load pretrained ResNet-18
        resnet = models.resnet18(pretrained=True)
        
        # Modify first conv layer to 3x3
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        
        # Remove first maxpool to maintain spatial dimensions
        # Keep other layers but remove final FC
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        return x  # Output will be [batch_size, 512, H/16, W/16]

def sample_episode(dataset, n_way=2, n_support=5, n_query=15):
    """
    Samples an episode for few-shot learning
    
    Args:
        dataset: Dataset to sample from (can be Dataset or Subset)
        n_way: Number of classes per episode
        n_support: Number of support examples per class
        n_query: Number of query examples per class
    
    Returns:
        support_images: Tensor of support set images
        support_labels: Tensor of support set labels
        query_images: Tensor of query set images
        query_labels: Tensor of query set labels
    """
    # Handle both Dataset and Subset cases
    if isinstance(dataset, torch.utils.data.Subset):
        original_dataset = dataset.dataset
        indices = dataset.indices
        # Get labels for the subset
        labels = [original_dataset.labels[i] for i in indices]
    else:
        labels = dataset.labels
        indices = range(len(dataset))
    
    # Get all available classes
    all_classes = sorted(list(set(labels)))
    
    # Randomly sample n_way classes
    selected_classes = random.sample(all_classes, n_way)
    
    # Initialize lists to store support and query examples
    support_images = []
    support_labels = []
    query_images = []
    query_labels = []
    
    # For each selected class
    for label_idx, class_label in enumerate(selected_classes):
        # Get all indices for this class
        class_indices = [i for i, (idx, label) in enumerate(zip(indices, labels)) if label == class_label]
        
        # Ensure we have enough examples
        if len(class_indices) < n_support + n_query:
            n_query = max(1, len(class_indices) - n_support)  # Ensure at least 1 query example
        
        # Sample support and query indices
        selected_indices = random.sample(class_indices, n_support + n_query)
        support_indices = selected_indices[:n_support]
        query_indices = selected_indices[n_support:n_support + n_query]
        
        # Get support examples
        for idx in support_indices:
            if isinstance(dataset, torch.utils.data.Subset):
                image, _ = dataset[idx]
            else:
                image, _ = dataset[indices[idx]]
            support_images.append(image)
            support_labels.append(label_idx)
        
        # Get query examples
        for idx in query_indices:
            if isinstance(dataset, torch.utils.data.Subset):
                image, _ = dataset[idx]
            else:
                image, _ = dataset[indices[idx]]
            query_images.append(image)
            query_labels.append(label_idx)
    
    # Convert to tensors
    support_images = torch.stack(support_images)
    support_labels = torch.tensor(support_labels)
    query_images = torch.stack(query_images)
    query_labels = torch.tensor(query_labels)
    
    return support_images, support_labels, query_images, query_labels



       

class AttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(AttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels // 8, in_channels, kernel_size=1)
        
    def forward(self, x):
        attention = F.relu(self.conv1(x))
        attention = torch.sigmoid(self.conv2(attention))
        return x * attention


class DynamicPatternExtractor(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256, num_patterns=7, num_iterations=3):
        super(DynamicPatternExtractor, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_patterns = num_patterns
        self.num_iterations = num_iterations
        self.input_dim = input_dim
        
        # Adaptive pooling for handling variable input sizes
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Dynamic pattern initialization network
        self.pattern_init = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_patterns * input_dim)
        )
        
        # Use PyTorch's built-in GRUCell instead of custom implementation
        self.gru = nn.GRUCell(input_dim + input_dim, hidden_dim)  # Combined feature dimensions
        
        # Pattern attention network
        self.pattern_attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # Single attention score per pattern
            nn.Sigmoid()
        )
        
        # Pattern update network
        self.pattern_update = nn.Sequential(
            nn.Linear(hidden_dim + input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        
        # Complexity estimation network
        self.complexity_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def init_patterns(self, x):
        """Initialize patterns based on input features"""
        batch_size = x.size(0)
        init_features = self.pattern_init(x)
        return init_features.view(batch_size, self.num_patterns, self.input_dim)
    
    def estimate_complexity(self, x):
        """Estimate image complexity to adjust pattern refinement"""
        complexity = self.complexity_estimator(x)
        return complexity.squeeze(-1)  # Remove last dimension for broadcasting
    
    def refine_patterns(self, patterns, features, complexity):
        """Refine patterns based on image complexity"""
        batch_size = patterns.size(0)
        h = None
        
        for _ in range(self.num_iterations):
            # Calculate pattern attention weights
            if h is not None:
                # Reshape h for attention calculation
                h_reshaped = h.view(batch_size * self.num_patterns, -1)
                attention = self.pattern_attention(h_reshaped)
                attention = attention.view(batch_size, self.num_patterns, 1)
                
                # Apply attention to patterns
                attended_patterns = patterns * attention
            else:
                attended_patterns = patterns
            
            # Prepare features for combination
            expanded_features = features.unsqueeze(1).expand(-1, self.num_patterns, -1)
            
            # Combine features and patterns
            combined = torch.cat([expanded_features, attended_patterns], dim=-1)
            
            # Update hidden state
            combined_flat = combined.view(batch_size * self.num_patterns, -1)
            if h is None:
                h = torch.zeros(batch_size * self.num_patterns, self.hidden_dim).to(patterns.device)
            
            h = self.gru(combined_flat, h)
            
            # Prepare inputs for pattern update
            h_reshaped = h.view(batch_size, self.num_patterns, -1)
            update_input = torch.cat([h_reshaped, attended_patterns], dim=-1)
            
            # Generate and apply updates
            updates = self.pattern_update(update_input)
            complexity_expanded = complexity.view(batch_size, 1, 1).expand(-1, self.num_patterns, self.input_dim)
            patterns = patterns + complexity_expanded * updates
            
        return patterns
        
    def forward(self, x):
        batch_size = x.size(0)
        
        # Handle spatial dimensions
        if len(x.shape) == 4:  # If input has spatial dimensions
            x = self.adaptive_pool(x)
            x = x.view(batch_size, -1)
        
        # Estimate image complexity
        complexity = self.estimate_complexity(x)
        
        # Initialize patterns
        patterns = self.init_patterns(x)
        
        # Refine patterns based on complexity
        refined_patterns = self.refine_patterns(patterns, x, complexity)
        
        return refined_patterns, complexity

class MTUNetPlusPlus(nn.Module):
    def __init__(self, num_classes):
        super(MTUNetPlusPlus, self).__init__()
        self.backbone = ModifiedResNet18()
        self.pattern_extractor = DynamicPatternExtractor(input_dim=512)
        self.attention = AttentionModule(512)
        
        # Add global pooling before classifier
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Add complexity-aware feature fusion
        self.fusion = nn.Sequential(
            nn.Linear(512 * 2, 512),  # Double input for concatenated features
            nn.ReLU(),
            nn.Linear(512, 512)
        )
        
        self.classifier = nn.Linear(512, num_classes)
    
    def forward(self, x, return_features=False):
        # Extract features using modified ResNet-18
        features = self.backbone(x)
        
        # Global average pooling for pattern extraction
        pooled_features = self.global_pool(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        
        # Extract patterns and complexity
        patterns, complexity = self.pattern_extractor(pooled_features)
        
        # Apply attention
        attended_features = self.attention(features)
        
        # Global average pooling for classification
        final_features = self.global_pool(attended_features)
        final_features = final_features.view(final_features.size(0), -1)
        
        # Complexity-aware feature fusion
        pattern_features = torch.mean(patterns, dim=1)  # Average patterns
        fused_features = torch.cat([final_features, pattern_features], dim=1)
        fused_features = self.fusion(fused_features)
        
        # Classification
        logits = self.classifier(fused_features)
        
        if return_features:
            return logits, fused_features, patterns, complexity
        return logits

def train_backbone(model, train_loader, num_epochs, device):
    """Train backbone CNN on the medical dataset"""
    optimizer = AdaBelief(model.backbone.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
        scheduler.step()

def train_attention(model, train_loader, num_epochs, device):
    """Train attention module independently"""
    optimizer = AdaBelief(model.attention.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    
    # Freeze backbone and pattern extractor
    for param in model.backbone.parameters():
        param.requires_grad = False
    for param in model.pattern_extractor.parameters():
        param.requires_grad = False
        
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(images)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
        scheduler.step()

def train_fewshot(model, train_loader, num_epochs, device):
    """Train few-shot classifier with reduced episodes per epoch for quick testing"""
    optimizer = AdaBelief([
        {'params': model.backbone.parameters(), 'lr': 1e-5},
        {'params': model.pattern_extractor.parameters(), 'lr': 1e-5},
        {'params': model.attention.parameters(), 'lr': 1e-4},
        {'params': model.classifier.parameters(), 'lr': 1e-4}
    ])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    
    for epoch in range(num_epochs):
        model.train()
        episodes_per_epoch = 10  # Reduced from 500 for quick testing
        total_loss = 0
        
        for episode in range(episodes_per_epoch):
            try:
                # Sample episode
                support_images, support_labels, query_images, query_labels = sample_episode(
                    train_loader.dataset, n_way=2, n_support=5, n_query=15
                )
                
                # Move to device
                support_images, support_labels = support_images.to(device), support_labels.to(device)
                query_images, query_labels = query_images.to(device), query_labels.to(device)
                
                optimizer.zero_grad()
                
                # Get features
                support_features = model(support_images, return_features=True)[1]
                query_features = model(query_images, return_features=True)[1]
                
                # Compute prototypes
                prototypes = compute_prototypes(support_features, support_labels)
                
                # Compute distances
                logits = compute_distances(query_features, prototypes)
                
                # Compute loss
                loss = F.cross_entropy(logits, query_labels)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                
                if episode % 2 == 0:  # Print more frequently for testing
                    print(f'Epoch [{epoch+1}/{num_epochs}], Episode [{episode+1}/{episodes_per_epoch}], '
                          f'Loss: {loss.item():.4f}')
                    
            except Exception as e:
                print(f"Error in episode {episode}: {str(e)}")
                continue
        
        avg_loss = total_loss / episodes_per_epoch
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')
        scheduler.step()

class HAM10000Dataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images organized in class folders
            transform (callable, optional): Optional transform to be applied on a sample
        """
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))  # Get class folders
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        
        self.images = []
        self.labels = []
        
        # Load all image paths and labels
        for class_name in self.classes:
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                for img_name in os.listdir(class_path):
                    if img_name.endswith(('.jpg', '.jpeg', '.png')):
                        self.images.append(os.path.join(class_path, img_name))
                        self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

def train_epoch(model, discriminator, train_loader, optimizer_G, optimizer_D, device):
    model.train()
    discriminator.train()
    total_loss = 0
    
    for batch_idx, (images, labels) in enumerate(tqdm(train_loader)):
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)
        
        # Train Discriminator
        optimizer_D.zero_grad()
        features = model.feature_extractor(images)
        d_real = discriminator(features.detach())
        # Use proper target shape
        real_labels = torch.ones(batch_size, 1).to(device)
        d_loss_real = F.binary_cross_entropy(d_real, real_labels)
        d_loss_real.backward()
        optimizer_D.step()
        
        # Train Generator (Feature Extractor) and Classifier
        optimizer_G.zero_grad()
        features = model.feature_extractor(images)
        d_fake = discriminator(features)
        g_loss = F.binary_cross_entropy(d_fake, real_labels)
        
        # Classification loss
        logits = model(images)
        cls_loss = F.cross_entropy(logits, labels)
        
        # Combined loss
        total_g_loss = cls_loss + 0.1 * g_loss
        total_g_loss.backward()
        optimizer_G.step()
        
        total_loss += total_g_loss.item()
        
        if batch_idx % 100 == 0:
            print(f'Batch [{batch_idx}/{len(train_loader)}], '
                  f'Loss: {total_g_loss.item():.4f}, '
                  f'Class Loss: {cls_loss.item():.4f}, '
                  f'G Loss: {g_loss.item():.4f}')
    
    return total_loss / len(train_loader)

def evaluate_episodes(model, dataset, num_episodes=2000, n_way=2, k_shot=1, n_query=15, device='cuda'):
    """
    Evaluate model on n-way k-shot tasks
    Args:
        model: trained model
        dataset: dataset to sample episodes from (can be Dataset or Subset)
        num_episodes: number of episodes to evaluate
        n_way: number of classes per episode
        k_shot: number of support examples per class
        n_query: number of query examples per class
        device: device to run evaluation on
    """
    model.eval()
    accuracies = []
    
    # Handle both Dataset and Subset cases
    if isinstance(dataset, torch.utils.data.Subset):
        original_dataset = dataset.dataset
        indices = dataset.indices
        # Get labels for the subset
        labels = [original_dataset.labels[i] for i in indices]
    else:
        labels = dataset.labels
        indices = range(len(dataset))
    
    # Get all available classes
    all_classes = sorted(list(set(labels)))
    
    if len(all_classes) < n_way:
        print(f"Warning: Only {len(all_classes)} classes available, but {n_way} classes requested.")
        n_way = len(all_classes)
    
    with torch.no_grad():
        for episode in range(num_episodes):
            try:
                # Randomly sample n classes
                episode_classes = random.sample(all_classes, n_way)
                
                # Get indices for each class
                support_indices = []
                query_indices = []
                
                for class_idx in episode_classes:
                    # Get all indices for this class in the subset
                    class_indices = [i for i, (idx, label) in enumerate(zip(indices, labels)) if label == class_idx]
                    
                    if len(class_indices) < k_shot + n_query:
                        # If not enough examples, use what we have
                        available = len(class_indices)
                        k_shot_actual = min(k_shot, available - 1)
                        n_query_actual = min(n_query, available - k_shot_actual)
                    else:
                        k_shot_actual = k_shot
                        n_query_actual = n_query
                    
                    # Sample k examples for support set
                    support = random.sample(class_indices, k_shot_actual)
                    # Sample remaining examples for query set
                    remaining = list(set(class_indices) - set(support))
                    query = random.sample(remaining, n_query_actual)
                    
                    support_indices.extend(support)
                    query_indices.extend(query)
                
                # Prepare support and query sets
                support_images = torch.stack([dataset[idx][0] for idx in support_indices]).to(device)
                support_labels = torch.tensor([labels[idx] for idx in support_indices]).to(device)
                query_images = torch.stack([dataset[idx][0] for idx in query_indices]).to(device)
                query_labels = torch.tensor([labels[idx] for idx in query_indices]).to(device)
                
                # Get model predictions
                support_features = model(support_images, return_features=True)[1]
                query_features = model(query_images, return_features=True)[1]
                
                # Calculate prototypes for each class
                prototypes = {}
                for cls in episode_classes:
                    cls_mask = support_labels == cls
                    cls_features = support_features[cls_mask]
                    if len(cls_features) > 0:  # Ensure we have features for this class
                        prototypes[cls] = cls_features.mean(0)
                
                # Calculate distances to prototypes
                accuracies_episode = []
                for i, query_feat in enumerate(query_features):
                    distances = {cls: torch.norm(query_feat - proto) for cls, proto in prototypes.items()}
                    if distances:  # Ensure we have distances to compute
                        predicted_cls = min(distances, key=distances.get)
                        correct = (predicted_cls == query_labels[i].item())
                        accuracies_episode.append(correct)
                
                # Calculate accuracy for this episode
                if accuracies_episode:  # Ensure we have accuracies to compute
                    accuracy = sum(accuracies_episode) / len(accuracies_episode)
                    accuracies.append(accuracy)
                
                if (episode + 1) % 100 == 0:
                    print(f'Episode {episode + 1}/{num_episodes}, Running Avg Accuracy: {np.mean(accuracies):.4f}')
                    
            except Exception as e:
                print(f"Error in episode {episode}: {str(e)}")
                continue
    
    if accuracies:
        final_accuracy = np.mean(accuracies)
        print(f'\nFinal Average Accuracy over {len(accuracies)} episodes: {final_accuracy:.4f}')
        return final_accuracy
    else:
        print("\nNo valid episodes completed. Please check dataset size and parameters.")
        return 0.0

def visualize_patterns_and_features(model, dataset, device, n_support=2, n_query=2, n_patterns=7, save_path=None):
    """
    Visualize patterns and features for a sampled task
    Args:
        model: trained model
        dataset: dataset to sample from
        device: device to run model on
        n_support: number of support images to show
        n_query: number of query images to show
        n_patterns: number of pattern slots (default 7)
        save_path: path to save the visualization
    """
    model.eval()
    
    # Sample images
    support_images, support_labels, query_images, query_labels = sample_episode(
        dataset, n_way=2, n_support=n_support, n_query=n_query
    )
    
    # Move to device
    support_images = support_images.to(device)
    query_images = query_images.to(device)
    
    with torch.no_grad():
        # Get patterns and features for support set
        _, support_features, support_patterns = model(support_images, return_features=True)
        # Get patterns and features for query set
        _, query_features, query_patterns = model(query_images, return_features=True)
    
    # Create figure with subplots
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(4, n_patterns + 1)
    
    # Plot support images and their patterns
    for i in range(n_support):
        # Original support image
        ax = fig.add_subplot(gs[i, 0])
        img = support_images[i].cpu().permute(1, 2, 0)
        img = (img - img.min()) / (img.max() - img.min())
        ax.imshow(img)
        ax.set_title(f'Support Image {i+1}')
        ax.axis('off')
        
        # Patterns for this support image
        for j in range(n_patterns):
            ax = fig.add_subplot(gs[i, j+1])
            pattern = support_patterns[i, j].reshape(int(np.sqrt(support_patterns.size(-1))), -1)
            pattern = pattern.cpu()
            pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min())
            ax.imshow(pattern, cmap='viridis')
            ax.set_title(f'Pattern {j+1}')
            ax.axis('off')
    
    # Plot query images and their patterns
    for i in range(n_query):
        # Original query image
        ax = fig.add_subplot(gs[i+n_support, 0])
        img = query_images[i].cpu().permute(1, 2, 0)
        img = (img - img.min()) / (img.max() - img.min())
        ax.imshow(img)
        ax.set_title(f'Query Image {i+1}')
        ax.axis('off')
        
        # Patterns for this query image
        for j in range(n_patterns):
            ax = fig.add_subplot(gs[i+n_support, j+1])
            pattern = query_patterns[i, j].reshape(int(np.sqrt(query_patterns.size(-1))), -1)
            pattern = pattern.cpu()
            pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min())
            ax.imshow(pattern, cmap='viridis')
            ax.set_title(f'Pattern {j+1}')
            ax.axis('off')
    
    # Plot average patterns
    ax = fig.add_subplot(gs[3, :])
    avg_patterns = torch.cat([support_patterns, query_patterns], dim=0).mean(dim=0)
    avg_patterns = avg_patterns.reshape(n_patterns, int(np.sqrt(avg_patterns.size(-1))), -1)
    avg_patterns = avg_patterns.cpu()
    avg_patterns = (avg_patterns - avg_patterns.min()) / (avg_patterns.max() - avg_patterns.min())
    
    # Create a horizontal stack of average patterns
    combined_patterns = torch.hstack([p for p in avg_patterns])
    ax.imshow(combined_patterns, cmap='viridis')
    ax.set_title('Overall Average Patterns')
    ax.axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Visualization saved to {save_path}")
    
    plt.show()

def visualize_dataset_patterns(model, dataset_name, dataset, device, num_samples=3, save_path=None):
    """
    Visualize patterns from multiple samples in a dataset
    Args:
        model: trained model
        dataset_name: name of the dataset (for title)
        dataset: dataset to sample from
        device: device to run model on
        num_samples: number of different samples to show
        save_path: path to save the visualization
    """
    model.eval()
    
    fig = plt.figure(figsize=(15, 5 * num_samples))
    
    for sample_idx in range(num_samples):
        # Sample a single image
        idx = np.random.randint(len(dataset))
        image, label = dataset[idx]
        image = image.unsqueeze(0).to(device)
        
        with torch.no_grad():
            # Get patterns
            _, _, patterns = model(image, return_features=True)
        
        # Plot original image
        ax = plt.subplot(num_samples, 8, sample_idx * 8 + 1)
        img = image[0].cpu().permute(1, 2, 0)
        img = (img - img.min()) / (img.max() - img.min())
        ax.imshow(img)
        ax.set_title(f'Sample {sample_idx + 1}')
        ax.axis('off')
        
        # Plot individual patterns
        for j in range(7):
            ax = plt.subplot(num_samples, 8, sample_idx * 8 + j + 2)
            pattern = patterns[0, j].reshape(int(np.sqrt(patterns.size(-1))), -1)
            pattern = pattern.cpu()
            pattern = (pattern - pattern.min()) / (pattern.max() - pattern.min())
            ax.imshow(pattern, cmap='viridis')
            ax.set_title(f'Pattern {j + 1}')
            ax.axis('off')
    
    plt.suptitle(f'Pattern Visualization - {dataset_name} Dataset', fontsize=16)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Visualization saved to {save_path}")
    
    plt.show()

# Function to test the visualizations
def test_visualizations(model, dataset, device):
    print("Generating pattern visualizations...")
    
    # Create output directory if it doesn't exist
    import os
    os.makedirs('visualizations', exist_ok=True)
    
    # Generate and save visualizations
    visualize_patterns_and_features(
        model, 
        dataset, 
        device,
        save_path='visualizations/patterns_and_features.png'
    )
    
    visualize_dataset_patterns(
        model,
        'HAM10000',
        dataset,
        device,
        save_path='visualizations/dataset_patterns.png'
    )
    
    print("Visualization complete! Check the 'visualizations' folder.")
    
def main():
    # Set device and data transforms
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((80, 80)),  # As per specifications
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create dataset and dataloader
    dataset = HAM10000Dataset(
        root_dir='/kaggle/input/ham10000-and-gan/synthetic_images',
        transform=transform
    )
    model = MTUNetPlusPlus(num_classes=7).to(device)
    
    # Sequential training
    print("Training backbone...")
    train_backbone(model, DataLoader(dataset, batch_size=32, shuffle=True), num_epochs=150, device=device)
    
    print("Training attention module...")
    train_attention(model, DataLoader(dataset, batch_size=32, shuffle=True), num_epochs=20, device=device)
    
    print("Training few-shot classifier...")
    train_fewshot(model, DataLoader(dataset, batch_size=32, shuffle=True), num_epochs=20, device=device)
    
    # Final evaluation
    print("Evaluating model...")
    evaluate_episodes(model, dataset, num_episodes=2000, n_way=2, k_shot=5, n_query=15, device=device)

    
if __name__ == "__main__":
   main()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 84.3MB/s]


Training backbone...
[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief
Rectification enab



Error in episode 0: name 'compute_prototypes' is not defined
Error in episode 1: name 'compute_prototypes' is not defined
Error in episode 2: name 'compute_prototypes' is not defined
Error in episode 3: name 'compute_prototypes' is not defined
Error in episode 4: name 'compute_prototypes' is not defined
Error in episode 5: name 'compute_prototypes' is not defined
Error in episode 6: name 'compute_prototypes' is not defined
Error in episode 7: name 'compute_prototypes' is not defined
Error in episode 8: name 'compute_prototypes' is not defined
Error in episode 9: name 'compute_prototypes' is not defined
Epoch [2/20], Average Loss: 0.0000
Error in episode 0: name 'compute_prototypes' is not defined
Error in episode 1: name 'compute_prototypes' is not defined
Error in episode 2: name 'compute_prototypes' is not defined
Error in episode 3: name 'compute_prototypes' is not defined
Error in episode 4: name 'compute_prototypes' is not defined
Error in episode 5: name 'compute_prototypes' is n