<a href="https://colab.research.google.com/github/Paul-Steve-Mithun/FSL_AUTONOMOUS_DRIVING/blob/main/FSL_TRAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch-summary torch-lr-finder timm easyfsl
!pip install torch torchvision torchaudio

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Collecting torch-lr-finder
  Downloading torch_lr_finder-0.2.1-py3-none-any.whl (11 kB)
Collecting timm
  Downloading timm-1.0.7-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting easyfsl
  Downloading easyfsl-1.5.0-py3-none-any.whl (72 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.8/72.8 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=0.4.1->torch-lr-finder)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=0.4.1->torch-lr-finder)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=0.4.1->torch-lr-finder)
  Using cached nvidia_c

In [None]:
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.models import ResNet50_Weights, resnet50
from torchsummary import summary
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

from google.colab import drive
drive.mount('/content/drive')

# Seed every RNG the notebook relies on. easyfsl's TaskSampler draws classes and images
# with Python's `random` module, so seeding torch alone does not make the episodes reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set the path to your dataset
data_path = '/content/drive/MyDrive/datasets/Steve_Dataset'  # Adjust this path

# One sub-directory per class; sorted to match the label order CustomDataset assigns.
class_name = sorted(os.listdir(data_path))
print(class_name)


Mounted at /content/drive
['Car', 'Truck', 'Bike', 'Pedestrians']


In [None]:
class PrototypicalNetworks(nn.Module):
    """Prototypical Networks few-shot classifier built on an arbitrary embedding backbone.

    A prototype is the mean backbone embedding of the support images of one class; each
    query is scored against every prototype by its negative Euclidean distance.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, support_images: torch.Tensor, support_labels: torch.Tensor, query_images: torch.Tensor) -> torch.Tensor:
        """Score every query image against the class prototypes of one episode.

        Args:
            support_images: support batch, passed to ``self.backbone`` as-is.
            support_labels: one integer label per support image.
            query_images: query batch, passed to ``self.backbone`` as-is.

        Returns:
            ``-torch.cdist(query_embeddings, prototypes)``; for 2-D embeddings this has shape
            ``(n_query, n_way)`` and the arg-max of each row is the predicted class. Column ``j``
            belongs to the ``j``-th smallest distinct support label (for labels ``0..n_way-1``
            that is simply label ``j``).
        """
        # Call the modules rather than `.forward()` so any registered hooks run.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # torch.unique yields the distinct labels in sorted order, so a prototype is only built
        # for labels that actually occur (an absent label would produce an empty, NaN mean).
        classes = torch.unique(support_labels)
        z_proto = torch.stack([z_support[support_labels == label].mean(0) for label in classes])

        return -torch.cdist(z_query, z_proto)

# Use ResNet-50 with its original ImageNet weights as the embedding backbone.
# IMAGENET1K_V1 is exactly what the deprecated `pretrained=True` flag loaded.
convolutional_network = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
# ResNet's forward already flattens the pooled features before `fc`, so an identity `fc`
# makes the backbone return those pooled features (2048-d for ResNet-50) as embeddings.
convolutional_network.fc = nn.Identity()
model = PrototypicalNetworks(convolutional_network)
print(model)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 129MB/s]


PrototypicalNetworks(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequenti

In [None]:
from PIL import Image
import os
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``data_path`` per class.

    Class indices follow the sorted sub-directory names. Every file is decoded once at
    construction time; files PIL fails to load are reported and skipped.

    Args:
        data_path: root directory containing one folder per class.
        transform: optional callable applied to each RGB PIL image in ``__getitem__``.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images = []  # image file paths, index-aligned with self.labels
        self.labels = []  # integer class index of each entry in self.images

        # Only directories are classes: a stray file (e.g. a README) would otherwise receive a
        # label and then make os.listdir() below fail.
        self.class_names = sorted(
            entry for entry in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, entry))
        )
        for label, class_name in enumerate(self.class_names):
            current_folder = os.path.join(data_path, class_name)
            # Sorted so item order (and hence seeded episode sampling) does not depend on the
            # filesystem's directory ordering.
            for file in sorted(os.listdir(current_folder)):
                fullpath = os.path.join(current_folder, file)
                try:
                    # load() forces a full decode so truncated/corrupted files are caught here
                    # rather than in the middle of training.
                    with Image.open(fullpath) as im:
                        im.load()
                except IOError:
                    print(f"Skipping corrupted image: {fullpath}")  # Log the skipped image
                    continue
                self.images.append(fullpath)
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def get_labels(self):
        """Return the class index of every item in dataset order (used by easyfsl's TaskSampler)."""
        return list(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``, or ``(None, None)`` if the file cannot be opened."""
        img_path = self.images[idx]
        try:
            handle = Image.open(img_path)
        except IOError:
            print(f"Error opening image: {img_path}")
            return None, None  # Return None for both image and label if error

        # The context manager closes the underlying file handle; convert() returns an
        # in-memory copy that stays valid after the block exits.
        with handle as image:
            # Palette images with transparency go through RGBA first so the transparency is
            # handled explicitly before dropping to RGB.
            if image.mode == 'P' and 'transparency' in image.info:
                image = image.convert('RGBA')
            image = image.convert('RGB')  # Ensure RGB format for consistency

        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Augmentation pipeline applied to every sampled image.
# NOTE(review): the random flip/rotation/colour jitter also apply whenever this loader is used for
# evaluation, so accuracies measured on it are not deterministic. The training cell further down
# also fine-tunes on `test_loader`, so scores from it are not held-out accuracies.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),  # degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the ImageNet-pretrained backbone.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

N_WAY = 4  # Number of classes in a task
N_SHOT = 15  # Number of images per class in the support set
N_QUERY = 1  # Number of images per class in the query set
N_EVALUATION_TASKS = 50  # Number of episodes the sampler yields per pass over the loader

test_set = CustomDataset(data_path, transform=transform)
# easyfsl's TaskSampler groups items by class through dataset.get_labels().
test_set.get_labels = lambda: list(test_set.labels)

# Fail fast with a readable message rather than a sampling error inside TaskSampler.
assert len(test_set.class_names) >= N_WAY, (
    f"N_WAY={N_WAY} but only {len(test_set.class_names)} classes found in {data_path}"
)
images_per_class = {
    name: test_set.labels.count(label) for label, name in enumerate(test_set.class_names)
}
too_small = {name: count for name, count in images_per_class.items() if count < N_SHOT + N_QUERY}
assert not too_small, (
    f"Every class needs at least N_SHOT + N_QUERY = {N_SHOT + N_QUERY} images; too few in: {too_small}"
)

test_sampler = TaskSampler(
    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=4,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)


In [None]:
# Function to evaluate one task
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> tuple[int, int]:
    """Score one few-shot episode with the global ``model`` on the global ``device``.

    Args:
        support_images, support_labels: the episode's support set.
        query_images, query_labels: the episode's query set and its ground-truth labels.

    Returns:
        ``(number of correctly classified queries, number of queries)``.
    """
    support_images = support_images.to(device)
    support_labels = support_labels.to(device)
    query_images = query_images.to(device)
    query_labels = query_labels.to(device)

    with torch.no_grad():
        scores = model(support_images, support_labels, query_images)
    # The column holding the highest score is the predicted episode label.
    predicted_labels = scores.argmax(dim=1)
    correct_predictions = (predicted_labels == query_labels).sum().item()
    return correct_predictions, len(query_labels)

# Function to evaluate the entire dataset
def evaluate(data_loader: DataLoader):
    """Evaluate the global ``model`` on every episode yielded by ``data_loader``.

    Switches ``model`` to eval mode, scores each episode with ``evaluate_on_one_task`` and
    prints the pooled query accuracy.

    Returns:
        Accuracy in percent over all query images of all episodes.

    Raises:
        ValueError: if the loader yields no query predictions (e.g. it is empty).
    """
    total_predictions = 0
    correct_predictions = 0

    model.eval()
    with torch.no_grad():
        # Each episode unpacks as (support_images, support_labels, query_images, query_labels, <unused>).
        for support_images, support_labels, query_images, query_labels, _ in tqdm(data_loader, total=len(data_loader)):
            correct, total = evaluate_on_one_task(support_images, support_labels, query_images, query_labels)
            total_predictions += total
            correct_predictions += correct

    if total_predictions == 0:
        raise ValueError("evaluate() got no query predictions; is the data loader empty?")

    accuracy = 100 * correct_predictions / total_predictions
    print(f"Model tested on {len(data_loader)} tasks. Accuracy: {accuracy:.2f}%")
    return accuracy

# Load the checkpoint
def load_checkpoint(file_path, model, map_location='cpu', strict=False):
    """Restore ``model`` weights and training metadata from a ``torch.save``-d checkpoint dict.

    Args:
        file_path: path to a dict with keys ``model_state_dict``, ``epoch``,
            ``mean_accuracy`` and ``std_accuracy``.
        model: module whose parameters are overwritten in place.
        map_location: forwarded to ``torch.load``; the CPU default lets GPU-saved checkpoints
            load on CPU-only runtimes.
        strict: forwarded to ``load_state_dict``. With the default False, mismatched keys are
            printed as warnings instead of raising.

    Returns:
        ``(epoch, mean_accuracy, std_accuracy)`` as stored in the checkpoint.

    Raises:
        KeyError: if one of the expected keys is missing from the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(file_path, map_location=map_location)
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint['model_state_dict'], strict=strict)

    if missing_keys:
        print("Warning: Missing keys in state_dict:", missing_keys)
    if unexpected_keys:
        print("Warning: Unexpected keys in state_dict:", unexpected_keys)

    epoch = checkpoint['epoch']
    mean_accuracy = checkpoint['mean_accuracy']
    std_accuracy = checkpoint['std_accuracy']
    print(f"Checkpoint loaded from epoch {epoch} with mean accuracy: {mean_accuracy:.2f}%")
    return epoch, mean_accuracy, std_accuracy


In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

# Previously saved checkpoint on Google Drive.
save_path = '/content/drive/MyDrive/prototypical_networks_evalmodel1.pth'

# Restore the weights; the returned metadata is kept for reference or for resuming training.
loaded_epoch, loaded_mean_accuracy, loaded_std_accuracy = load_checkpoint(save_path, model)


Checkpoint loaded from epoch 10 with mean accuracy: 92.80%


In [None]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over every episode yielded by ``dataloader``.

    Each batch unpacks into ``(support_images, support_labels, query_images, query_labels, _)``;
    the fifth element is ignored. The loss compares the model's query scores with the query
    labels.

    Returns:
        Mean of the per-episode loss values.
    """
    model.train()
    running_loss = 0
    for support_images, support_labels, query_images, query_labels, _ in tqdm(dataloader):
        support_images, support_labels = support_images.to(device), support_labels.to(device)
        query_images, query_labels = query_images.to(device), query_labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(support_images, support_labels, query_images), query_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)

# Episodic fine-tuning of the whole network.
# NOTE(review): both training and the per-epoch evaluation use `test_loader`, so the reported
# accuracy is measured on the same images the model is trained on, not on held-out data.
N_EPOCHS = 10  # single source of truth for the loop bound and the progress message

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Parameters default to requires_grad=True; set explicitly in case layers were frozen earlier
# in the session.
for param in convolutional_network.parameters():
    param.requires_grad = True  # Unfreeze all layers for fine-tuning

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Divide the learning rate by 10 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = nn.CrossEntropyLoss()

for epoch in range(N_EPOCHS):
    avg_loss = train(model, test_loader, optimizer, criterion, device)
    print(f"Epoch [{epoch+1}/{N_EPOCHS}], Loss: {avg_loss:.4f}")
    scheduler.step()
    evaluate(test_loader)


  2%|▏         | 1/50 [01:55<1:34:18, 115.49s/it]

In [None]:
save_path = '/content/drive/MyDrive/FSL_FinalModel.pth'  # Adjust the path as needed


def save_checkpoint(model, optimizer, epoch, mean_accuracy, std_accuracy, file_path):
    """Write model/optimizer state plus training metadata to ``file_path`` via ``torch.save``.

    The saved dict has the keys ``model_state_dict``, ``optimizer_state_dict``, ``epoch``,
    ``mean_accuracy`` and ``std_accuracy``; ``load_checkpoint`` reads all of them except the
    optimizer state.
    """
    state = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
        mean_accuracy=mean_accuracy,
        std_accuracy=std_accuracy,
    )
    torch.save(state, file_path)
    print(f"Checkpoint saved at epoch {epoch} with mean accuracy: {mean_accuracy:.2f}%")

# `evaluate` prints and returns the pooled query accuracy (%) over every sampled task.
# NOTE(review): `test_loader` is also the loader the model was fine-tuned on, so this is not a held-out score.
mean_accuracy = evaluate(test_loader)
# `evaluate` does not compute a per-task spread; 0 is stored as a placeholder.
std_accuracy = 0

save_checkpoint(
    model,
    optimizer,
    epoch=20,  # hard-coded; presumably the 10 loaded epochs plus 10 fine-tuning epochs — TODO confirm
    mean_accuracy=mean_accuracy,
    std_accuracy=std_accuracy,
    file_path=save_path,
)

100%|██████████| 100/100 [02:37<00:00,  1.57s/it]


Model tested on 100 tasks. Accuracy: 96.75%
Checkpoint saved at epoch 20 with mean accuracy: 96.75%
