In [1]:
import copy
from pathlib import Path
import random
from statistics import mean
import numpy as np
import torch
from torch import nn
from tqdm import tqdm

In [2]:
!pip install easyfsl

Collecting easyfsl
  Downloading easyfsl-1.5.0-py3-none-any.whl.metadata (16 kB)
Downloading easyfsl-1.5.0-py3-none-any.whl (72 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.8/72.8 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: easyfsl
Successfully installed easyfsl-1.5.0


In [3]:
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split

In [4]:
random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [5]:
n_way = 5
n_shot = 5
n_query = 10

DEVICE = "cuda"
n_workers = 12
n_tasks_per_epoch = 500
n_validation_tasks = 100

In [6]:
train_transform = transforms.Compose([
    
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=45),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),  
    transforms.ToTensor(),
    
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

In [7]:
test_data_path = "/kaggle/input/bm-dataset/BM_Test"
val_data_path = "/kaggle/input/bm-dataset/BM_Val"
final_data_path = "/kaggle/input/bm-dataset/BM_Final"
val_set = ImageFolder(root=val_data_path, transform=test_transform)
val_set.get_labels = lambda: [instance[1] for instance in val_set]
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)




In [8]:
from abc import abstractmethod
from typing import Optional
import torch
from torch import Tensor, nn
from easyfsl.methods.utils import compute_prototypes

class MetaClassifier(nn.Module):

    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
        use_softmax: bool = False,
        feature_centering: Optional[Tensor] = None,
        feature_normalization: Optional[float] = None,
        
    ):
        
        super().__init__()

        self.backbone = backbone if backbone is not None else nn.Identity()
        self.use_softmax = use_softmax

        self.prototypes = torch.tensor(())
        self.support_features = torch.tensor(())
        self.support_labels = torch.tensor(())

        self.feature_centering = (
            feature_centering if feature_centering is not None else torch.tensor(0)
        )
        self.feature_normalization = feature_normalization

    @abstractmethod
    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
       
        raise NotImplementedError(
            
        )
    def compute_prototypes(support_features: Tensor, support_labels: Tensor) -> Tensor:
    

        n_way = len(torch.unique(support_labels))
    
        return torch.cat(
            [
                support_features[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

    def process_support_set(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ):
        
        self.compute_prototypes_and_store_support_set(support_images, support_labels)

    @staticmethod
    def is_transductive() -> bool:
        raise NotImplementedError(
            "All few-shot algorithms must implement a is_transductive method."
        )
    def compute_prototypes_and_store_support_set(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ):
        self.support_labels = support_labels
        self.support_features = self.compute_features(support_images)
        self._raise_error_if_features_are_multi_dimensional(self.support_features)
        self.prototypes = compute_prototypes(self.support_features, support_labels)
    def compute_features(self, images: Tensor) -> Tensor:
        
        original_features = self.backbone(images)
        centered_features = original_features - self.feature_centering
        if self.feature_normalization is not None:
            return nn.functional.normalize(
                centered_features, p=self.feature_normalization, dim=1
            )
        return centered_features
    
    def softmax_if_specified(self, output: Tensor, temperature: float = 1.0) -> Tensor:
        
        return (temperature * output).softmax(-1) if self.use_softmax else output
    def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        
        return -torch.cdist(samples, self.prototypes)
    @staticmethod
    def _raise_error_if_features_are_multi_dimensional(features: Tensor):
        if len(features.shape) != 2:
            raise ValueError(
                "Illegal backbone or feature shape. "
                "Expected output for an image is a 1-dim tensor."
            )


In [9]:
class PrototypicalNetworks(MetaClassifier):
    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
        use_softmax: bool = False,
        feature_centering: Optional[Tensor] = None,
        feature_normalization: Optional[float] = None,
    ):
        super().__init__(backbone, use_softmax, feature_centering, feature_normalization)
        self.train_loader = None  # Initialize train_loader attribute

    def forward(self, query_images: Tensor) -> Tensor:
        query_features = self.compute_features(query_images)
        self._raise_error_if_features_are_multi_dimensional(query_features)
        scores = self.l2_distance_to_prototypes(query_features)
        return self.softmax_if_specified(scores)

    @staticmethod
    def is_transductive() -> bool:
        return False

    def get_data_loader(self):
        return self.train_loader


In [10]:
from torch.optim import SGD, Optimizer
def training_epoch(
    model: MetaClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()
            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)


In [11]:
from typing import List, Optional, Tuple
from sklearn.metrics import f1_score, precision_score, recall_score
def evaluate_on_one_task(
    model: FewShotClassifier,
    support_images: Tensor,
    support_labels: Tensor,
    query_images: Tensor,
    query_labels: Tensor,
) -> Tuple[int, int]:
    
    model.process_support_set(support_images, support_labels)
    predictions = model(query_images).detach().data
    number_of_correct_predictions = int(
        (torch.max(predictions, 1)[1] == query_labels).sum().item()
    )
    return number_of_correct_predictions, len(query_labels)
def evaluate(
    model: FewShotClassifier,
    data_loader: DataLoader,
    device: str = "cuda",
    use_tqdm: bool = True,
    tqdm_prefix: Optional[str] = None,
) -> Tuple[float, float, float, float]:
    total_predictions = 0
    correct_predictions = 0
    all_predictions = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        with tqdm(
            enumerate(data_loader),
            total=len(data_loader),
            disable=not use_tqdm,
            desc=tqdm_prefix,
        ) as tqdm_eval:
            for _, (
                support_images,
                support_labels,
                query_images,
                query_labels,
                _,
            ) in tqdm_eval:
                correct, total = evaluate_on_one_task(
                    model,
                    support_images.to(device),
                    support_labels.to(device),
                    query_images.to(device),
                    query_labels.to(device),
                )

                total_predictions += total
                correct_predictions += correct

                all_predictions.extend(model(query_images).detach().cpu().numpy().argmax(axis=1))
                all_targets.extend(query_labels.cpu().numpy())

                tqdm_eval.set_postfix(accuracy=correct_predictions / total_predictions)

    accuracy = correct_predictions / total_predictions
    precision = precision_score(all_targets, all_predictions, average='weighted')
    recall = recall_score(all_targets, all_predictions, average='weighted')
    f1 = f1_score(all_targets, all_predictions, average='weighted')

    return accuracy, precision, recall, f1

In [12]:
def create_clients(train_paths, few_shot_classifier, transform, n_way, n_shot, n_query, n_tasks_per_epoch, n_workers, optimizer, loss_function):
    clients = []
    for i, path in enumerate(train_paths):
        client_name = 'client_' + str(i)
        
        # Load data
        train_set = ImageFolder(root=path, transform=transform)
        train_set.get_labels = lambda: [instance[1] for instance in train_set]
        
        # Define the task sampler
        train_sampler = TaskSampler(train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch)
        
        # Create DataLoader with the task sampler
        train_loader = DataLoader(
            train_set,
            batch_sampler=train_sampler,
            num_workers=n_workers,
            pin_memory=True,
            collate_fn=train_sampler.episodic_collate_fn,
        )
        
        
        few_shot_classifier.train_loader = train_loader
        
        clients.append((client_name, few_shot_classifier))
    
    return clients

# Define your paths for different subsets of data
train_data_paths = ["/kaggle/input/bm-train/BM_Train1",
                    "/kaggle/input/bm-train/BM_Train2",
                    "/kaggle/input/bm-train/BM_Train3",
                    "/kaggle/input/bm-train/BM_Train4"]

In [13]:
from torchvision.models import resnet18

convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()
convolutional_network = nn.DataParallel(convolutional_network)
#print(convolutional_network)
meta_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 118MB/s] 


In [14]:

from torch.optim.lr_scheduler import MultiStepLR



LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 1
scheduler_milestones = [120, 160]
scheduler_gamma = 0.1
learning_rate = 1e-2


train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)



In [15]:
clients = create_clients(train_data_paths, few_shot_classifier, train_transform, n_way, n_shot, n_query, n_tasks_per_epoch, n_workers, train_optimizer, LOSS_FUNCTION)



In [16]:
n_rounds = 3
n_epochs_per_round = 1
best_validation_accuracy = 0.0
best_state = None

for round in range(n_rounds):
    print(f"Round {round + 1}/{n_rounds}")
    
    
    client_weights = {}
    for client_name, client_model in clients:
        print(f"Training {client_name}")
        
        
        client_model = PrototypicalNetworks(convolutional_network).to(DEVICE)
        client_model.train_loader = meta_classifier.train_loader  
        
        train_loader = client_model.train_loader
        for _ in range(n_epochs_per_round):
            average_loss = training_epoch(client_model, train_loader, train_optimizer)
        
        client_weights[client_name] = copy.deepcopy(client_model.state_dict())

    
    weighted_average_weights = {}
    total_clients = len(clients)
    for name, weights in client_weights.items():
        for param_name, param in weights.items():
            if name == 'client_0': 
                weighted_average_weights[param_name] = param / total_clients
            else:
                weighted_average_weights[param_name] += param / total_clients

    meta_classifier.load_state_dict(weighted_average_weights)

    # Evaluate the updated central model on the validation set
    validation_accuracy, precision, recall, f1 = evaluate(few_shot_classifier, val_loader, DEVICE, tqdm_prefix="Validation")

    # Save the best model if validation accuracy improves
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        torch.save(best_state, "best_model.pth")
        print("New best model saved!")

    # Step the scheduler
    train_scheduler.step()

# Load the best model state
few_shot_classifier.load_state_dict(best_state)


Round 1/3
Training client_0


Training: 100%|██████████| 500/500 [02:09<00:00,  3.87it/s, loss=0.474]

Training client_1



Training: 100%|██████████| 500/500 [02:04<00:00,  4.00it/s, loss=0.0897]

Training client_2



Training: 100%|██████████| 500/500 [02:05<00:00,  3.98it/s, loss=0.0226]

Training client_3



Training: 100%|██████████| 500/500 [02:04<00:00,  4.01it/s, loss=0.00937]
Validation: 100%|██████████| 100/100 [00:19<00:00,  5.02it/s, accuracy=0.706]


New best model saved!
Round 2/3
Training client_0


Training: 100%|██████████| 500/500 [02:05<00:00,  3.99it/s, loss=0.0136]

Training client_1



Training: 100%|██████████| 500/500 [02:04<00:00,  4.01it/s, loss=0.00639]

Training client_2



Training: 100%|██████████| 500/500 [02:04<00:00,  4.01it/s, loss=0.0036] 

Training client_3



Training: 100%|██████████| 500/500 [02:04<00:00,  4.02it/s, loss=0.00191]
Validation: 100%|██████████| 100/100 [00:19<00:00,  5.06it/s, accuracy=0.677]

Round 3/3
Training client_0



Training: 100%|██████████| 500/500 [02:04<00:00,  4.02it/s, loss=0.00195]

Training client_1



Training: 100%|██████████| 500/500 [02:04<00:00,  4.01it/s, loss=0.00182]

Training client_2



Training: 100%|██████████| 500/500 [02:04<00:00,  4.01it/s, loss=0.00201]

Training client_3



Training: 100%|██████████| 500/500 [02:04<00:00,  4.00it/s, loss=0.0011] 
Validation: 100%|██████████| 100/100 [00:19<00:00,  5.05it/s, accuracy=0.657]


<All keys matched successfully>

In [None]:
import copy

n_rounds = 3
n_epochs_per_round = 1
best_validation_accuracy = 0.0
best_state = None


few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)

for round in range(n_rounds):
    print(f"Round {round + 1}/{n_rounds}")
    
    # Train each client individually for one epoch on the current global model
    client_weights = {}
    for client_name, client_model in clients:
        print(f"Training {client_name} on global model")
        
        
        client_model.load_state_dict(copy.deepcopy(few_shot_classifier.state_dict()))
        
        # Train the client model for one epoch
        for _ in range(n_epochs_per_round):
            average_loss = training_epoch(client_model, client_model.train_loader, train_optimizer)
        
        # Save the trained weights of the client model
        client_weights[client_name] = copy.deepcopy(client_model.state_dict())

    # Perform weighted averaging of the weights
    weighted_average_weights = {}
    total_clients = len(clients)
    for name, weights in client_weights.items():
        for param_name, param in weights.items():
            if name == 'client_0':  # For the first client, initialize the weighted average weights
                weighted_average_weights[param_name] = param / total_clients
            else:
                weighted_average_weights[param_name] += param / total_clients

    # Update the global model with the weighted average weights
    few_shot_classifier.load_state_dict(weighted_average_weights)

    # Evaluate the updated global model on the validation set
    validation_accuracy, precision, recall, f1 = evaluate(few_shot_classifier, val_loader, DEVICE, tqdm_prefix="Validation")

    # Save the best model if validation accuracy improves
    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        torch.save(best_state, "best_model.pth")
        print("New best model saved!")

    # Step the scheduler
    train_scheduler.step()

# Load the best model state
few_shot_classifier.load_state_dict(best_state)


In [None]:
best_state = torch.load("best_model.pth")

In [18]:
n_way = 3
n_shot = 3
n_query = 7

In [19]:
n_tasks_per_epoch = 10
n_validation_tasks = 100
n_test_tasks = 100
n_epochs = 20
test_set = ImageFolder(root=test_data_path, transform=train_transform)
final_set = ImageFolder(root=final_data_path, transform=test_transform)
test_set.get_labels = lambda: [instance[1] for instance in test_set]
final_set.get_labels = lambda: [instance[1] for instance in final_set]


In [20]:
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
final_sampler = TaskSampler(
    final_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)
final_loader = DataLoader(
    final_set,
    batch_sampler=final_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=final_sampler.episodic_collate_fn,
)



In [21]:
new_model = resnet18(weights=None).to(DEVICE)
new_model.fc = nn.Flatten()
new_model = nn.DataParallel(new_model)
#print(new_model)
new_model = PrototypicalNetworks(convolutional_network).to(DEVICE)
new_model.load_state_dict(best_state)

<All keys matched successfully>

In [22]:
test_optimizer = SGD(
    new_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
test_scheduler = MultiStepLR(
    test_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

In [23]:
best_final_accuracy = 0.0
final_state = new_model.state_dict()
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(new_model, test_loader, test_optimizer)
    final_accuracy, precision, recall, f1 = evaluate(
        new_model, final_loader, device=DEVICE, tqdm_prefix="Validation"
    )
    if final_accuracy > best_final_accuracy:
        best_final_accuracy = final_accuracy
        print("new best model!")
        final_state = copy.deepcopy(new_model.state_dict())
        torch.save(final_state, "final_model.pth")
    train_scheduler.step()

Epoch 0


Training: 100%|██████████| 10/10 [00:02<00:00,  4.39it/s, loss=0.512]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.92it/s, accuracy=0.946]


new best model!
Epoch 1


Training: 100%|██████████| 10/10 [00:02<00:00,  4.57it/s, loss=0.361]
Validation: 100%|██████████| 100/100 [00:15<00:00,  6.56it/s, accuracy=0.872]

Epoch 2



Training: 100%|██████████| 10/10 [00:02<00:00,  4.70it/s, loss=0.281]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.80it/s, accuracy=0.919]

Epoch 3



Training: 100%|██████████| 10/10 [00:02<00:00,  4.63it/s, loss=0.19] 
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.68it/s, accuracy=0.973]


new best model!
Epoch 4


Training: 100%|██████████| 10/10 [00:02<00:00,  4.56it/s, loss=0.213]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.80it/s, accuracy=0.902]

Epoch 5



Training: 100%|██████████| 10/10 [00:02<00:00,  4.05it/s, loss=0.142]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.68it/s, accuracy=0.926]

Epoch 6



Training: 100%|██████████| 10/10 [00:02<00:00,  4.67it/s, loss=0.173]
Validation: 100%|██████████| 100/100 [00:15<00:00,  6.64it/s, accuracy=0.878]

Epoch 7



Training: 100%|██████████| 10/10 [00:02<00:00,  4.65it/s, loss=0.0894]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.82it/s, accuracy=0.964]

Epoch 8



Training: 100%|██████████| 10/10 [00:02<00:00,  4.56it/s, loss=0.0816]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.72it/s, accuracy=0.963]

Epoch 9



Training: 100%|██████████| 10/10 [00:02<00:00,  4.59it/s, loss=0.0645]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.74it/s, accuracy=0.938]

Epoch 10



Training: 100%|██████████| 10/10 [00:02<00:00,  4.57it/s, loss=0.0474]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.75it/s, accuracy=0.963]

Epoch 11



Training: 100%|██████████| 10/10 [00:02<00:00,  4.62it/s, loss=0.0277]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.88it/s, accuracy=0.932]

Epoch 12



Training: 100%|██████████| 10/10 [00:02<00:00,  4.72it/s, loss=0.0224]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.70it/s, accuracy=0.921]

Epoch 13



Training: 100%|██████████| 10/10 [00:02<00:00,  4.70it/s, loss=0.0481]
Validation: 100%|██████████| 100/100 [00:15<00:00,  6.65it/s, accuracy=0.955]

Epoch 14



Training: 100%|██████████| 10/10 [00:02<00:00,  4.10it/s, loss=0.0437]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.80it/s, accuracy=0.962]

Epoch 15



Training: 100%|██████████| 10/10 [00:02<00:00,  4.59it/s, loss=0.0471]
Validation: 100%|██████████| 100/100 [00:15<00:00,  6.62it/s, accuracy=0.965]

Epoch 16



Training: 100%|██████████| 10/10 [00:02<00:00,  4.61it/s, loss=0.0209]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.77it/s, accuracy=0.957]

Epoch 17



Training: 100%|██████████| 10/10 [00:02<00:00,  4.60it/s, loss=0.0135]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.77it/s, accuracy=0.943]

Epoch 18



Training: 100%|██████████| 10/10 [00:02<00:00,  4.69it/s, loss=0.0103]
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.79it/s, accuracy=0.916]

Epoch 19



Training: 100%|██████████| 10/10 [00:02<00:00,  4.60it/s, loss=0.015] 
Validation: 100%|██████████| 100/100 [00:14<00:00,  6.82it/s, accuracy=0.935]


In [None]:
new_model.load_state_dict(final_state)
accuracy = evaluate(new_model, final_loader, device=DEVICE)


In [36]:
print(f"Average accuracy: {100 * accuracy[0]:.2f} %")
print(f"Precision: {100 * accuracy[1]:.2f} %")
print(f"Recall: {100 * accuracy[2]:.2f} %")
print(f"F1-score: {100 * accuracy[3]:.2f} %")

Average accuracy: 96.71 %
Precision: 96.81 %
Recall: 96.71 %
F1-score: 96.69 %


In [None]:
new_model.laod_state_dict(final_state)