In [1]:
!pip install easyfsl
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average


Defaulting to user installation because normal site-packages is not writeable




In [2]:
import os
import random
from collections import defaultdict
from PIL import Image, ImageOps
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PrototypicalOmniglotDataset(Dataset):
    def __init__(self, root, num_classes=1623, n_shot=5, n_query=10, transform=None):
        self.root = root
        self.num_classes = num_classes
        self.n_shot = n_shot
        self.n_query = n_query
        self.transform = transform
        self.samples_by_label = defaultdict(list)
        self.all_imgs = {}
        self.classes = []

        # Common image file extensions
        image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp']

        # Organize data into classes and image paths
        for alphabet in os.listdir(self.root):
            alphabet_path = os.path.join(self.root, alphabet)
            if not os.path.isdir(alphabet_path):
                continue

            for char_class in os.listdir(alphabet_path):
                char_class_path = os.path.join(alphabet_path, char_class)

                # Skip if the entry is not a directory
                if not os.path.isdir(char_class_path):
                    continue

                all_images = []
                for img_name in os.listdir(char_class_path):
                    img_path = os.path.join(char_class_path, img_name)

                    # Check if it's a file with a recognized image extension
                    if os.path.isfile(img_path) and any(img_name.lower().endswith(ext) for ext in image_extensions):
                        all_images.append(img_path)

                if all_images:
                    char_class_name = f"{alphabet}_{char_class}"
                    self.samples_by_label[char_class_name] = list(range(len(all_images)))
                    self.all_imgs[char_class_name] = all_images
                    self.classes.append(char_class_name)

    def transform_image(self, raw_img):
#         img = ImageOps.invert(raw_img)
#         if self.transform is not None:
        img = self.transform(raw_img)
        return img

    def __getitem__(self, index):
        selected_classes = random.sample(self.classes, self.num_classes)
        class_indices = [self.samples_by_label[each_cls] for each_cls in selected_classes]

        support_set = []
        query_set = []
        label_id = 0
        qs=[]
        for idx_set in class_indices:
            # Creating support set
            selected_support = random.sample(idx_set, self.n_shot)
            support_images = [self.transform_image(Image.open(self.all_imgs[selected_classes[label_id]][each]).convert('RGB')) for each in selected_support]

            # Creating query set
            selected_query = random.sample(idx_set, self.n_query)
            query_images = [self.transform_image(Image.open(self.all_imgs[selected_classes[label_id]][each]).convert('RGB')) for each in selected_query]
            support_set.append((support_images, [label_id for _ in range(self.n_shot)]))
            # print("aaaaaaaaaa",label_id)
            query_set.append((query_images,  [label_id for _ in range(self.n_query)]))

            label_id += 1
            # qs= [item for item in a for _ in range(self.n_query)]
        return support_set, query_set

    def __len__(self):
        return 10

# Set up the transformations
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.ToTensor(),
])

# Set up the SiameseOmniglotDataset
# Training == evaluation
root_path = './final_training_processed_data'
num_classes = 5  # Set the desired number of classes per episode
n_shot = 5
n_query = 10

PrototypicalOmniglotDatasetLoader = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=n_query, transform=transform)


In [None]:

transform = transforms.Compose([
    transforms.Resize((200, 200)),  # Resize to expected size
    transforms.RandomRotation(degrees=30),  # Random rotation to address orientation issues
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
    transforms.RandomVerticalFlip(p=0.5),    # Random vertical flip
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),  # Stronger color jitter to account for color variation
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Random translation for slight shift
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization (adjust if needed)
])


# Set up the PrototypicalOmniglotDatasetLoader
root_path ='./final_training_processed_data'# '/path/to/omniglot_dataset'
num_classes = 3
n_shot = 2
n_query = 3

dataset = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=n_query, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
import matplotlib.pyplot as plt
import numpy as np

# Assuming you have a PrototypicalNetwork model
# model = PrototypicalNetwork(num_classes=num_classes).to(device)

# Set the model to evaluation mode
# model.eval()

# Choose an episode to visualize
episode_index = 10  # Change this to the desired episode index

# Get support set and query set for the chosen episode
support_set, query_set = dataset[episode_index]

# Visualize support set
for i, (support_images, lbl) in enumerate(support_set):
    for j, img in enumerate(support_images):
        plt.subplot(len(support_set), len(support_images), i * len(support_images) + j + 1)
        plt.imshow(img[0], cmap='gray')  # Assuming the image is grayscale
        plt.axis('off')
        plt.title(f'Support {i + 1}-{j + 1}{lbl[j]}')

plt.suptitle('Support Set')
plt.show()

# Visualize query set
for i, (query_images, ll) in enumerate(query_set):
    for j, img in enumerate(query_images):
        plt.subplot(len(query_set), len(query_images), i * len(query_images) + j + 1)
        plt.imshow(img[0], cmap='gray')  # Assuming the image is grayscale
        plt.axis('off')
        plt.title(f'Query {i + 1}-{ll[j]}')

plt.suptitle('Query Set')
plt.show()


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import ViTImageProcessor, ViTModel
import torchvision.models as models

class ResNet18WithDropout(nn.Module):
    def __init__(self, pretrained=False, dr=0.5):
        super(ResNet18WithDropout, self).__init__()
        self.resnet18 = models.resnet18(pretrained=pretrained)
        self.dropout = nn.Dropout(dr)
        num_features = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_features, 1000)  # Example output size

    def forward(self, x):
        x = self.resnet18(x)
        x = self.dropout(x)
        return x

class PrototypicalNetworks3(nn.Module):
    def __init__(self, backbone: nn.Module, combined_dim: int, dropout_rate: float = 0.5):
        super(PrototypicalNetworks3, self).__init__()
        self.backbone = backbone
        self.vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
        self.vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        
        # Freeze the weights of the ViT model
        self.freeze_vit()

        # FC layers with dropout
        self.fc = nn.Sequential(
            nn.Linear(combined_dim, 600),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(600, 500),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(500, 400)
        )

    def freeze_vit(self):
        for param in self.vit_model.parameters():
            param.requires_grad = False

    def rescale_to_vit_range(self, images: torch.Tensor) -> torch.Tensor:
        """
        Rescale image tensor to [0, 1] for ViT input.
        """
        images_rescaled = (images - images.min()) / (images.max() - images.min())
        return images_rescaled

    def forward(self, support_images, support_labels, query_images: torch.Tensor) -> torch.Tensor:
        """
        Predict query labels using labeled support images.
        """
        num_classes = len(support_labels)
        combined_prototypes = []

        for class_label in range(num_classes):
            support_images_eachclass = support_images[class_label]

            # Backbone features
            class_features_backbone = self.backbone(support_images_eachclass.to(device)).mean(dim=0)

            # ViT features
            support_images_vit = self.rescale_to_vit_range(support_images_eachclass)
            support_images_vit = self.vit_processor(images=support_images_vit, return_tensors="pt")['pixel_values']
            class_features_vit = self.vit_model(support_images_vit.to(device)).last_hidden_state.mean(dim=1).mean(dim=0)

            # Combine features
            combined_features = torch.cat((class_features_backbone, class_features_vit), dim=-1)
            combined_prototype = self.fc(combined_features)
            combined_prototypes.append(combined_prototype)
        
        combined_prototypes = torch.stack(combined_prototypes)

        # Backbone query features
        query_features_backbone = self.backbone(query_images)

        # ViT query features
        query_images_vit = self.rescale_to_vit_range(query_images)
        query_images_vit = self.vit_processor(images=query_images_vit, return_tensors="pt")['pixel_values']
        query_features_vit = self.vit_model(query_images_vit.to(device)).last_hidden_state.mean(dim=1).mean(dim=0)

        # Expand ViT features to match backbone features in terms of batch size
        query_features_vit = query_features_vit.unsqueeze(0).expand(query_features_backbone.shape[0], -1)

        # Combine query features
        combined_query_features = torch.cat((query_features_backbone, query_features_vit), dim=-1)
        combined_query_features = self.fc(combined_query_features)

        # Compute the distance between combined query features and combined prototypes
        distances = torch.cdist(combined_query_features, combined_prototypes)

        return -distances

# Instantiate the modified ResNet18 model with dropout
backbone = ResNet18WithDropout(pretrained=True, dr=0.3).to(device)

# Set combined_dim as the sum of the output dimensions from ResNet18 and ViT
combined_dim = 1000 + 768  # Ensure these dimensions match your model output sizes

# Instantiate the PrototypicalNetworks3 model

model = PrototypicalNetworks3(backbone=backbone, combined_dim=combined_dim,dropout_rate=0.3).to(device)

# Example input data
support_images = torch.randn(3, 4, 3, 200, 200).to(device)  # 3 classes, 4 support samples per class, 3 channels, 200x200 images
support_labels = torch.tensor([0, 1, 2]).to(device)
query_images = torch.randn(4, 3, 200, 200).to(device)  # 4 query samples, 3 channels, 200x200 images

# Forward pass
distances = model(support_images, support_labels, query_images)

# Print total and learnable parameter counts
def print_parameter_counts(model):
    total_params = sum(p.numel() for p in model.parameters())
    learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Learnable parameters: {learnable_params}")

print_parameter_counts(model)

# Configure optimizer with weight decay for L2 regularization
optimizer = optim.Adam(
    [
        {'params': model.backbone.parameters(), 'weight_decay': 1e-4},  # Apply weight decay to backbone
        {'params': model.fc.parameters(), 'weight_decay': 1e-4}  # Apply weight decay to FC layers
    ],
    lr=1e-3,
    weight_decay=1e-4 # Adjust weight decay here if needed
)
print(distances.shape)

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Total parameters: 99641060
Learnable parameters: 13251812
torch.Size([4, 3])


In [None]:




        
        
transform = transforms.Compose([
    transforms.Resize((200, 200)),  # Resize to the size expected by the model (e.g., 224x224)
    transforms.RandomHorizontalFlip(),  # Random horizontal flip for augmentation
    transforms.RandomVerticalFlip(),    # Random vertical flip for augmentation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color jitter for augmentation
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize with ImageNet statistics (adjust if needed)
])
        
        
# 
#background == Training

root_path = './final_training_processed_data'
num_classes = 5
n_shot = 10
n_query = 10
# generate_random_params_trainloader()
print(f"Generated random values - num_classes: {num_classes}, n_shot: {n_shot}, n_query: {n_query}")

optimizer = optim.Adam(
    [
        {'params': model.backbone.parameters(), 'weight_decay': 1e-2},  # Apply weight decay to backbone
        {'params': model.fc.parameters(), 'weight_decay': 1e-2}  # Apply weight decay to FC layers
    ],
    lr=0.0001,
    weight_decay=1e-2 # Adjust weight decay here if needed
)
print(distances.shape)


criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.004)
episode_losses = []

for episode in range(500):

    support_set, query_set = dataset[episode]

    # Convert data to tensors and move to device
    support_set = [(torch.stack(images).to(device), label[0]) for images, label in support_set]
    query_set = [(torch.stack(images).to(device), label) for images, label in query_set]

    # Forward pass
    support_images, support_labels = zip(*support_set)
    query_images, query_labels = zip(*query_set)
    # print(len(support_images),len(support_labels),type(support_images[0]),support_labels,)

    # support_images = torch.cat(support_images)
    # support_labels = torch.tensor(support_labels).to(device)
    query_images = torch.cat(query_images)
    query_labels = torch.tensor(query_labels).to(device)
    query_labels = query_labels.view(-1)
    tskloss = 0
    ta=""
    # print(query_labels)
    for i in range(40):
      model.train()

      classification_scores = model(support_images, support_labels, query_images)
      # print(classification_scores.shape)


      loss = criterion(classification_scores, query_labels)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      tskloss+=loss.item()
      if(i==0 or i==19 or i==39):
        predicted_labels = torch.argmax(classification_scores, dim=1)
        pp=predicted_labels.tolist()
        act=query_labels.tolist()
        cortt=0
        for iii in range(len(pp)):
          if pp[iii] == act[iii]:
              cortt = cortt + 1
        print("Traning accuracy: ",cortt/len(pp))
        ta="Training Accuracy: ,"+str(cortt/len(pp))
      print("    --->>>   episode",episode,"mini_epoch:",i,"  train loss: ",loss.item())

# Calculate the total number of matching values
# total_matching_values = sum(matching_counts)

    episode_losses.append(tskloss)
    num_classes = random.choice([5,8])  # Random number of classes (e.g., between 3 and 10)
    n_shot = random.randint(1, 5)       # Random number of support examples per class (e.g., between 5 and 15)
    n_query = 15-n_shot
    dataset = PrototypicalOmniglotDataset(root=root_path, num_classes=num_classes, n_shot=n_shot, n_query=n_query, transform=transform)
    print(f"Generated  Tasks - num_classes: {num_classes}, n_shot: {n_shot}, n_query: {n_query}")

    if episode%5==0:
      torch.save(model.state_dict(), './cGVit_res18__1'+'/G_vit_r189____'+str(episode)+'.pth')

Generated random values - num_classes: 5, n_shot: 10, n_query: 10
torch.Size([4, 3])
Traning accuracy:  0.1
    --->>>   episode 0 mini_epoch: 0   train loss:  2.180184841156006
    --->>>   episode 0 mini_epoch: 1   train loss:  1.8622276782989502
    --->>>   episode 0 mini_epoch: 2   train loss:  1.6817309856414795
    --->>>   episode 0 mini_epoch: 3   train loss:  1.5386207103729248
    --->>>   episode 0 mini_epoch: 4   train loss:  1.4209096431732178
    --->>>   episode 0 mini_epoch: 5   train loss:  1.3588522672653198
    --->>>   episode 0 mini_epoch: 6   train loss:  1.336037278175354
    --->>>   episode 0 mini_epoch: 7   train loss:  1.4053798913955688
    --->>>   episode 0 mini_epoch: 8   train loss:  1.271716594696045
    --->>>   episode 0 mini_epoch: 9   train loss:  1.2621400356292725
    --->>>   episode 0 mini_epoch: 10   train loss:  1.2799463272094727
    --->>>   episode 0 mini_epoch: 11   train loss:  1.2650247812271118
    --->>>   episode 0 mini_epoch: 12   t

    --->>>   episode 2 mini_epoch: 29   train loss:  0.36127176880836487
    --->>>   episode 2 mini_epoch: 30   train loss:  0.42144450545310974
    --->>>   episode 2 mini_epoch: 31   train loss:  0.4022805988788605
    --->>>   episode 2 mini_epoch: 32   train loss:  0.373588502407074
    --->>>   episode 2 mini_epoch: 33   train loss:  0.35719144344329834
    --->>>   episode 2 mini_epoch: 34   train loss:  0.35926347970962524
    --->>>   episode 2 mini_epoch: 35   train loss:  0.3328379690647125
    --->>>   episode 2 mini_epoch: 36   train loss:  0.3043932616710663
    --->>>   episode 2 mini_epoch: 37   train loss:  0.31544020771980286
    --->>>   episode 2 mini_epoch: 38   train loss:  0.35348501801490784
Traning accuracy:  0.9272727272727272
    --->>>   episode 2 mini_epoch: 39   train loss:  0.29855120182037354
Generated  Tasks - num_classes: 8, n_shot: 1, n_query: 14
Traning accuracy:  0.09821428571428571
    --->>>   episode 3 mini_epoch: 0   train loss:  4.8732810020446

    --->>>   episode 5 mini_epoch: 18   train loss:  0.8677341938018799
Traning accuracy:  0.6979166666666666
    --->>>   episode 5 mini_epoch: 19   train loss:  0.7879533767700195
    --->>>   episode 5 mini_epoch: 20   train loss:  0.7476314902305603
    --->>>   episode 5 mini_epoch: 21   train loss:  0.7391100525856018
    --->>>   episode 5 mini_epoch: 22   train loss:  0.715923547744751
    --->>>   episode 5 mini_epoch: 23   train loss:  0.6906638741493225
    --->>>   episode 5 mini_epoch: 24   train loss:  0.7409502863883972
    --->>>   episode 5 mini_epoch: 25   train loss:  0.6409460306167603
    --->>>   episode 5 mini_epoch: 26   train loss:  0.7244570255279541
    --->>>   episode 5 mini_epoch: 27   train loss:  0.6112342476844788
    --->>>   episode 5 mini_epoch: 28   train loss:  0.7635172009468079
    --->>>   episode 5 mini_epoch: 29   train loss:  0.5276405811309814
    --->>>   episode 5 mini_epoch: 30   train loss:  0.6736646294593811
    --->>>   episode 5 mini

    --->>>   episode 8 mini_epoch: 5   train loss:  1.7273598909378052
    --->>>   episode 8 mini_epoch: 6   train loss:  1.7327370643615723
    --->>>   episode 8 mini_epoch: 7   train loss:  1.5950970649719238
    --->>>   episode 8 mini_epoch: 8   train loss:  1.8308639526367188
    --->>>   episode 8 mini_epoch: 9   train loss:  1.757757306098938
    --->>>   episode 8 mini_epoch: 10   train loss:  1.4447715282440186
    --->>>   episode 8 mini_epoch: 11   train loss:  1.4456099271774292
    --->>>   episode 8 mini_epoch: 12   train loss:  1.3635109663009644
    --->>>   episode 8 mini_epoch: 13   train loss:  1.3610376119613647
    --->>>   episode 8 mini_epoch: 14   train loss:  1.2911171913146973
    --->>>   episode 8 mini_epoch: 15   train loss:  1.2931281328201294
    --->>>   episode 8 mini_epoch: 16   train loss:  1.325044870376587
    --->>>   episode 8 mini_epoch: 17   train loss:  1.238075613975525
    --->>>   episode 8 mini_epoch: 18   train loss:  1.0488547086715698


    --->>>   episode 10 mini_epoch: 35   train loss:  0.30032244324684143
    --->>>   episode 10 mini_epoch: 36   train loss:  0.33275482058525085
    --->>>   episode 10 mini_epoch: 37   train loss:  0.3485969007015228
    --->>>   episode 10 mini_epoch: 38   train loss:  0.29090234637260437
Traning accuracy:  0.7571428571428571
    --->>>   episode 10 mini_epoch: 39   train loss:  0.3440403938293457
Generated  Tasks - num_classes: 5, n_shot: 5, n_query: 10
Traning accuracy:  0.24
    --->>>   episode 11 mini_epoch: 0   train loss:  1.835506796836853
    --->>>   episode 11 mini_epoch: 1   train loss:  1.5776230096817017
    --->>>   episode 11 mini_epoch: 2   train loss:  1.2889492511749268
    --->>>   episode 11 mini_epoch: 3   train loss:  1.4612513780593872
    --->>>   episode 11 mini_epoch: 4   train loss:  1.5753531455993652
    --->>>   episode 11 mini_epoch: 5   train loss:  1.4147480726242065
    --->>>   episode 11 mini_epoch: 6   train loss:  1.1955245733261108
    --->>

    --->>>   episode 13 mini_epoch: 22   train loss:  0.2596912384033203
    --->>>   episode 13 mini_epoch: 23   train loss:  0.25467434525489807
    --->>>   episode 13 mini_epoch: 24   train loss:  0.2425687164068222
    --->>>   episode 13 mini_epoch: 25   train loss:  0.13799671828746796
    --->>>   episode 13 mini_epoch: 26   train loss:  0.23667804896831512
    --->>>   episode 13 mini_epoch: 27   train loss:  0.1668073982000351
    --->>>   episode 13 mini_epoch: 28   train loss:  0.06472112983465195
    --->>>   episode 13 mini_epoch: 29   train loss:  0.15049460530281067
    --->>>   episode 13 mini_epoch: 30   train loss:  0.03886856138706207
    --->>>   episode 13 mini_epoch: 31   train loss:  0.11157850921154022
    --->>>   episode 13 mini_epoch: 32   train loss:  0.04192839562892914
    --->>>   episode 13 mini_epoch: 33   train loss:  0.10911959409713745
    --->>>   episode 13 mini_epoch: 34   train loss:  0.05840284749865532
    --->>>   episode 13 mini_epoch: 35   