In [94]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset,Subset
from torchvision import transforms,datasets
import numpy as np
import matplotlib.pyplot as plt
import random
from tqdm import tqdm
from torch.optim import Adam
import os
from torch.utils.data import random_split
from PIL import Image

In [95]:
class PrototypicalNetwork(nn.Module):
    def __init__(self, encoder):
        super(PrototypicalNetwork, self).__init__()
        self.encoder = encoder
        
        # Start with encoder unfrozen by default
        for param in self.encoder.parameters():
            param.requires_grad = True
    
    def forward(self, support, query):
        """
        Args:
            support: [n_way, k_shot, C, H, W]
            query: [n_way * n_query, C, H, W]
        Returns:
            logits: [n_way * n_query, n_way]
        """
        n_way, k_shot = support.shape[:2]
        
        # Flatten support and encode
        support_flat = support.view(-1, *support.shape[-3:])
        support_features = self.encoder(support_flat)  # [n_way*k_shot, feature_dim]
        
        # Encode query images
        query_features = self.encoder(query)  # [n_way*n_query, feature_dim]
        
        # Reshape support features and compute prototypes
        support_features = support_features.view(n_way, k_shot, -1)  # [n_way, k_shot, feature_dim]
        prototypes = torch.mean(support_features, dim=1)  # [n_way, feature_dim]
        
        # Compute squared Euclidean distances
        distances = torch.cdist(query_features, prototypes, p=2).pow(2)
        
        return -distances  # Negative distances as logits

In [96]:
class Encoder(nn.Module):
    def __init__(self, hidden_dim=64, out_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, hidden_dim, kernel_size=3, padding=1),  # 28x28
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 14x14

            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 7x7

            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> 3x3

            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))  # -> 1x1
        )
        self.fc = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = self.encoder(x)          # Shape: (B, hidden_dim, 1, 1)
        x = x.view(x.size(0), -1)    # Flatten to (B, hidden_dim)
        return self.fc(x)           

In [97]:
encoder = Encoder()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.load_state_dict(torch.load("simsiam_encoder.pth", map_location=device))

  encoder.load_state_dict(torch.load("simsiam_encoder.pth", map_location=device))


<All keys matched successfully>

In [98]:
import matplotlib.pyplot as plt
import os

class TrainingVisualizer:
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.train_accs = []
        self.val_accs = []
        plt.ion()  # Interactive mode on
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(12, 5))
    
    def update(self, train_loss, val_loss, train_acc, val_acc):
        """Update the plots with new metrics"""
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.train_accs.append(train_acc)
        self.val_accs.append(val_acc)
        
        # Clear previous plots
        self.ax1.clear()
        self.ax2.clear()
        
        # Plot loss
        self.ax1.plot(self.train_losses, label='Train Loss', color='blue')
        self.ax1.plot(self.val_losses, label='Val Loss', color='orange')
        self.ax1.set_title('Training & Validation Loss')
        self.ax1.set_xlabel('Epoch')
        self.ax1.set_ylabel('Loss')
        self.ax1.legend()
        self.ax1.grid(True)
        
        # Plot accuracy
        self.ax2.plot(self.train_accs, label='Train Accuracy', color='green')
        self.ax2.plot(self.val_accs, label='Val Accuracy', color='red')
        self.ax2.set_title('Training & Validation Accuracy')
        self.ax2.set_xlabel('Epoch')
        self.ax2.set_ylabel('Accuracy (%)')
        self.ax2.legend()
        self.ax2.grid(True)
        
        plt.tight_layout()
        plt.pause(0.1)  # Pause to update the display
    
    def save(self, save_dir='results'):
        """Save the final plots to disk"""
        os.makedirs(save_dir, exist_ok=True)
        
        # Save as PNG
        plt.savefig(os.path.join(save_dir, 'training_curves.png'))
        
        # Save data for later analysis
        torch.save({
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accs': self.train_accs,
            'val_accs': self.val_accs
        }, os.path.join(save_dir, 'training_history.pt'))
        
        plt.close()

In [99]:

# 1. Fixed Dataset Class
class ScriptDataset(Dataset):
    def __init__(self, root_dir, transform=None, augment=True):
        self.root_dir = root_dir
        self.transform = transform
        self.augment = augment
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.samples = self._make_dataset()
        
    def _make_dataset(self):
        samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                samples.append((img_path, self.class_to_idx[class_name]))
        return samples
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('L')  # Grayscale
        img = transforms.Resize((28, 28))(img)
        
        if self.augment:
            img = transforms.RandomAffine(
                degrees=15, translate=(0.1, 0.1), scale=(0.8, 1.2)
            )(img)
            img = transforms.RandomPerspective(distortion_scale=0.2, p=0.5)(img)
        
        if self.transform:
            img = self.transform(img)
            
        return img, label
    
    def __len__(self):
        return len(self.samples)

# 2. Fixed Data Splitting
def get_datasets(root_dir, val_ratio=0.2):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    full_dataset = ScriptDataset(root_dir, transform=transform, augment=True)
    
    # Create indices for splitting
    num_samples = len(full_dataset)
    indices = list(range(num_samples))
    split = int(np.floor(val_ratio * num_samples))
    
    # Shuffle indices
    np.random.shuffle(indices)
    
    # Split into train and val
    train_idx, val_idx = indices[split:], indices[:split]
    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)
    
    # Disable augmentation for validation
    full_dataset.augment = False  # Affects validation through Subset
    
    return train_dataset, val_dataset

# 3. Fixed Episode Iterator
class EpisodeIterator:
    def __init__(self, subset, n_way=5, k_shot=5, n_query=15):
        self.subset = subset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        
        # Get the original dataset from Subset
        self.dataset = subset.dataset
        self.indices = subset.indices
        
        # Create label to indices mapping
        self.label_to_indices = {}
        for idx in self.indices:
            _, label = self.dataset.samples[idx]
            if label not in self.label_to_indices:
                self.label_to_indices[label] = []
            self.label_to_indices[label].append(idx)
        
        self.available_classes = list(self.label_to_indices.keys())
    
    def __iter__(self):
        return self
    
    def __next__(self):
        # Select n_way random classes with enough samples
        selected_classes = random.sample(
            [cls for cls in self.available_classes 
             if len(self.label_to_indices[cls]) >= self.k_shot + self.n_query],
            self.n_way
        )
        
        support = []
        query = []
        
        for class_idx in selected_classes:
            # Get all samples for this class
            class_indices = self.label_to_indices[class_idx]
            selected = random.sample(class_indices, self.k_shot + self.n_query)
            
            # Add to support and query sets
            support.extend(selected[:self.k_shot])
            query.extend(selected[self.k_shot:])
        
        # Convert indices to actual data
        support_data = [self.dataset[i] for i in support]
        query_data = [self.dataset[i] for i in query]
        
        # Stack into tensors
        support_imgs = torch.stack([img for img, _ in support_data])
        query_imgs = torch.stack([img for img, _ in query_data])
        
        return support_imgs, query_imgs
    
    def __len__(self):
        return 100  # Number of episodes per epoch

def train_with_visualization(model, root_dir, epochs=50, n_way=5, k_shot=5, n_query=15, val_ratio=0.2):
    # Get datasets
    train_dataset, val_dataset = get_datasets(root_dir, val_ratio)
    
    # Setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    
    visualizer = TrainingVisualizer()
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss, train_acc = 0, 0
        count = 0
        
        episode_iterator = EpisodeIterator(train_dataset, n_way, k_shot, n_query)
        for support, query in episode_iterator:
            support = support.to(device)
            query = query.to(device)
            
            # Forward pass
            logits = model(support, query)
            labels = torch.arange(n_way).repeat(n_query).to(device)
            
            # Compute loss and accuracy
            loss = F.cross_entropy(logits, labels)
            acc = (logits.argmax(dim=1) == labels).float().mean()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_acc += acc.item()
            count += 1
        
        train_loss /= count
        train_acc = 100 * train_acc / count
        
        # Validation phase
        model.eval()
        val_loss, val_acc = 0, 0
        count = 0
        
        with torch.no_grad():
            episode_iterator = EpisodeIterator(val_dataset, n_way, k_shot, n_query)
            for support, query in episode_iterator:
                support = support.to(device)
                query = query.to(device)
                
                logits = model(support, query)
                labels = torch.arange(n_way).repeat(n_query).to(device)
                
                loss = F.cross_entropy(logits, labels)
                acc = (logits.argmax(dim=1) == labels).float().mean()
                
                val_loss += loss.item()
                val_acc += acc.item()
                count += 1
        
        val_loss /= count
        val_acc = 100 * val_acc / count
        
        # Update visualizer
        visualizer.update(train_loss, val_loss, train_acc, val_acc)
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
        
        print(f'Epoch {epoch+1}/{epochs}: '
              f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | '
              f'Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%')
        
        scheduler.step()
    
    visualizer.save('results')
    return model

In [None]:
# Initialize your model
model = PrototypicalNetwork(encoder=encoder).to(device)

# Train with visualization
trained_model = train_with_visualization(
    model,
    root_dir="./dataset",
    epochs=100,
    n_way=5,
    k_shot=5,
    n_query=15,
    val_ratio=0.2
)