In [28]:
import numpy as np
import os
from scipy.io import loadmat
import matplotlib.pyplot as plt
import seaborn as sns

import pickle
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import timeit
from  torch.utils.data import DataLoader
from torch.utils.data import random_split
from tqdm.notebook import tqdm
from sklearn import preprocessing
from collections import defaultdict

import time
import random
import collections
from copy import deepcopy

In [3]:
# naming convention: x0_y0_p100_b40_s1_t1.mat
# room_scenarios = [3,4,5,8,10,11,12,13,21]
# corridor_scenarios = [16,17,18,24,25,30,31,32,33]
# open_scenarios = [6,7,9,14,15,19,20,26,27]

In [5]:
# Directory containing the processed .mat files
data_folder = 'data/processed/'
file_names = os.listdir(data_folder)

# Multiplier applied by read_importance_vector; NOTE(review): presumably the
# grid-to-physical scale factor — confirm.
grid_width = 60
num_bursts_each_file = 40

# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [6]:
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch so runs repeat.

    Args:
        seed: integer seed applied, in this order, to ``random``, ``numpy`` and ``torch``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [30]:
from torch.utils.data import Dataset
from torchvision import datasets, transforms
 
    
class CsiMetaDataset(Dataset):
    """In-memory dataset of CSI packets belonging to one scenario.

    Every ``.mat`` file whose name follows ``x{X}_y{Y}_..._t{scenario}.mat`` contributes
    up to ``shots`` packets, taken from the start of the third axis of its
    ``list_packets`` array (real part only). Each packet is labelled with the (X, Y)
    coordinates parsed from the file name.

    Args:
        data_folder: directory containing the ``.mat`` files (trailing separator optional).
        scenario: scenario id; it must equal the last ``_``-separated token of the file
            stem, i.e. ``t{scenario}``.
        dataset_type: ``'linear'`` flattens each packet to 1-D; ``'1dConv'`` keeps the
            2-D ``list_packets[:, :, i]`` slice.
        shots: maximum packets per file; ``None`` (or 0) keeps every packet.

    Raises:
        ValueError: if ``dataset_type`` is not ``'linear'`` or ``'1dConv'``.
    """

    def __init__(self, data_folder, scenario, dataset_type, shots=None):
        if dataset_type not in ['linear', '1dConv']:
            raise ValueError('Wrong type of dataset')

        self.list_samples = []
        scenario_token = f't{scenario}'

        # Sorted so the sample order (and therefore seeded DataLoader shuffling) does
        # not depend on the filesystem's directory-listing order.
        for file_name in sorted(os.listdir(data_folder)):
            stem, extension = os.path.splitext(file_name)
            tokens = stem.split('_')
            # Exact token match: a substring test would also accept e.g. '..._st3.mat'
            # or '..._t3.mat.bak' for scenario 3.
            if extension != '.mat' or tokens[-1] != scenario_token:
                continue
            x = int(tokens[0][1:])
            y = int(tokens[1][1:])

            packets = np.real(loadmat(os.path.join(data_folder, file_name))['list_packets'])
            n_packets = packets.shape[2]
            n_keep = n_packets if not shots else min(shots, n_packets)

            for packet_id in range(n_keep):  # TODO: don't always take packets from the start
                packet = packets[:, :, packet_id]
                if dataset_type == 'linear':
                    packet = packet.flatten()
                self.list_samples.append({'csi_data': packet,
                                          'coordinates': torch.tensor([x, y])})

    def __len__(self):
        """Number of packets loaded."""
        return len(self.list_samples)

    def __getitem__(self, idx):
        """Return the sample dict ``{'csi_data': ndarray, 'coordinates': tensor([x, y])}``."""
        return self.list_samples[idx]
                   

In [31]:
# Module-level placeholders; DataloderClass keeps its own per-instance lists.
train_support_dataloaders, train_query_dataloaders = [], []
test_support_dataloaders, test_query_dataloaders = [], []

class DataloderClass():
    """Per-scenario support/query DataLoaders for meta-training and meta-testing.

    For every scenario two ``CsiMetaDataset`` objects (``dataset_type='1dConv'``) are
    built: a support set with ``shots`` packets per file and a query set with
    ``query_shots`` packets per file. Each is wrapped in a shuffled ``DataLoader``.

    NOTE(review): both sets take packets from the start of every file, so the support
    packets are also the first packets of the query set. The query loss is therefore
    not measured on held-out packets — confirm this is intended.

    Args:
        train_scenarios: iterable of scenario ids used for meta-training.
        test_scenarios: iterable of scenario ids used for meta-testing.
        batch_size: DataLoader batch size.
        shots: packets per file in each support set.
        query_shots: packets per file in each query set (default 300, as before).
        folder: directory of ``.mat`` files; defaults to the notebook-level ``data_folder``.
    """

    def __init__(self, train_scenarios, test_scenarios, batch_size, shots=5,
                 query_shots=300, folder=None):
        source = data_folder if folder is None else folder

        self.train_support_dataloaders, self.train_query_dataloaders = self._build_loaders(
            tqdm(train_scenarios), source, batch_size, shots, query_shots)
        self.test_support_dataloaders, self.test_query_dataloaders = self._build_loaders(
            test_scenarios, source, batch_size, shots, query_shots)

    @staticmethod
    def _build_loaders(scenarios, folder, batch_size, shots, query_shots):
        """Return ``(support_loaders, query_loaders)``, one shuffled DataLoader per scenario."""
        support_loaders, query_loaders = [], []
        for scenario in scenarios:
            for loaders, n_shots in ((support_loaders, shots), (query_loaders, query_shots)):
                dataset = CsiMetaDataset(folder, scenario, dataset_type='1dConv', shots=n_shots)
                loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=True))
        return support_loaders, query_loaders

In [32]:
dropout_rate = 0.1  # NOTE(review): not referenced by the visible model code


def eulicidain_distance_loss(pred, target):
    """Mean Euclidean (L2) distance between predicted and target points.

    Args:
        pred: tensor of predictions; the distance is taken over dim 1.
        target: tensor broadcastable against ``pred``.

    Returns:
        Scalar tensor: mean over rows of ``sqrt(sum((pred - target) ** 2, dim=1))``.
    """
    squared_distance = (pred - target).square().sum(dim=1)
    return squared_distance.sqrt().mean()

class CnnLocalization(torch.nn.Module):
    """1-D CNN that regresses a 2-D position from a CSI packet.

    Architecture: two ``Conv1d(padding=1) -> MaxPool1d(2) -> ReLU`` stages, a flatten,
    then an MLP ``105-128-128-64-32-16-8-2`` with ReLU between linear layers.
    The first Linear expects 105 = 15 channels * 7 positions after the two poolings.
    That presumably corresponds to inputs of shape (batch, 3, 30) — TODO confirm.
    """

    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv1d(3, 10, 3, padding=1),
            nn.MaxPool1d(2),
            nn.ReLU(True),

            nn.Conv1d(10, 15, 3, padding=1),
            nn.MaxPool1d(2),
            nn.ReLU(True),

            nn.Flatten(),

            nn.Linear(105, 128),
            nn.ReLU(True),
            nn.Linear(128, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, 8),
            nn.ReLU(True),
            nn.Linear(8, 2),
        )

    def forward(self, x):
        """Standard forward pass using the module's own parameters."""
        return self.layers(x)

    def functional_forward(self, x, params):
        """Run the network with externally supplied weights (used by MAML inner loops).

        The sequence of operations and their hyper-parameters (padding, kernel size, …)
        are read from ``self.layers``, so this stays in sync with ``forward`` if the
        architecture changes. Only weights and biases come from ``params``.

        Args:
            x: input batch for the first Conv1d.
            params: mapping from ``named_parameters()`` names (e.g. ``'layers.0.weight'``)
                to tensors; these may be non-leaf tensors produced by gradient steps.

        Returns:
            Output of the final layer (the last Linear, without activation).

        Raises:
            TypeError: if ``self.layers`` contains a layer type not handled here.
        """
        for index, layer in enumerate(self.layers):
            prefix = f'layers.{index}.'
            if isinstance(layer, nn.Conv1d):
                bias = params[prefix + 'bias'] if layer.bias is not None else None
                x = F.conv1d(x, weight=params[prefix + 'weight'], bias=bias,
                             stride=layer.stride, padding=layer.padding,
                             dilation=layer.dilation, groups=layer.groups)
            elif isinstance(layer, nn.MaxPool1d):
                x = F.max_pool1d(x, kernel_size=layer.kernel_size, stride=layer.stride,
                                 padding=layer.padding, dilation=layer.dilation,
                                 ceil_mode=layer.ceil_mode)
            elif isinstance(layer, nn.ReLU):
                # Out-of-place on purpose: adapted params may need higher-order gradients.
                x = F.relu(x)
            elif isinstance(layer, nn.Flatten):
                x = x.flatten(layer.start_dim, layer.end_dim)
            elif isinstance(layer, nn.Linear):
                bias = params[prefix + 'bias'] if layer.bias is not None else None
                x = F.linear(x, weight=params[prefix + 'weight'], bias=bias)
            else:
                raise TypeError(f'functional_forward does not support {type(layer).__name__}')
        return x
    
# Standalone model instance; the meta-learner classes construct their own models.
cnn_v1 = CnnLocalization()
cnn_v1 = cnn_v1.to(device)
# summary(cnn_v1,(3,30))

In [33]:
def plot_loss_over_iteration(training_results, scale=60):
    """Plot per-iteration support/query losses for meta-training and meta-testing.

    Args:
        training_results: tuple as returned by ``MAML.run`` / ``Adaptive_MAML.run``.
            Element 0 holds one ``[support, query]`` meta-train loss per iteration and
            element 1 the meta-test losses. An empty series (e.g. a run with
            ``train=False``) is skipped instead of raising IndexError.
        scale: multiplier applied to every loss before plotting (default 60, the value
            previously hard-coded). NOTE(review): presumably the grid width. The
            meta-learners use ``nn.MSELoss``, and a linear factor does not convert a
            squared error into physical units — confirm the intended scaling.

    Returns:
        The matplotlib Axes containing the plot.
    """
    fig, ax = plt.subplots(figsize=(15, 8))
    for losses, prefix in ((training_results[0], "Meta Train"),
                           (training_results[1], "Meta Test")):
        losses = np.asarray(losses)
        if losses.size == 0:
            continue
        ax.plot(losses[:, 0] * scale, label=f"{prefix}-Support Loss")
        ax.plot(losses[:, 1] * scale, label=f"{prefix}-Query Loss")
    ax.set_xlabel("Meta iteration")
    ax.set_ylabel(f"Loss x {scale}")
    ax.set_title("Meta-learning losses")
    ax.legend()
    return ax

In [47]:
class MAML():
    """MAML / first-order MAML meta-learner for ``CnnLocalization``.

    Each task is one scenario; a single task step works as follows:
    1. Adapt a copy of the parameters with plain SGD on the task's support DataLoader
       for ``inner_epochs`` passes.
    2. Evaluate the adapted parameters on the task's query DataLoader.
    3. Back-propagate that query loss into the shared initial parameters and apply
       one Adam step.

    Args:
        inner_step_size: SGD step size of the inner (adaptation) loop.
        inner_epochs: passes over the support set during adaptation.
        outer_step_size: initial Adam learning rate of the meta-update.
        first_order: if True, inner gradients are not differentiated through
            (first-order MAML); otherwise the second-order graph is kept.
        scheduler_gamma: StepLR decay factor. NOTE: the scheduler is stepped once per
            meta-training *task* (in ``train``), so ``step_size=30`` counts tasks.
        inner_loop_print_frequncy: print inner-loop progress every this many inner
            epochs; 0 disables printing.
    """

    def __init__(self, inner_step_size: float, inner_epochs: int, outer_step_size: float,
                 first_order: bool, scheduler_gamma=0.9,
                 inner_loop_print_frequncy=5):
        self.inner_step_size = inner_step_size
        self.inner_epochs = inner_epochs
        self.outer_step_size = outer_step_size
        self.first_order = first_order

        self.inner_loop_print_frequncy = inner_loop_print_frequncy

        self.time_start = 0

        self.model = CnnLocalization().to(device)
        # Same Parameter objects the optimizer updates in place, so this list stays current.
        self.model_params = list(self.model.parameters())

        self.loss_func = nn.MSELoss()  # eulicidain_distance_loss is an alternative

        self.meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=outer_step_size)
        self.scheduler = optim.lr_scheduler.StepLR(self.meta_optimizer, step_size=30, gamma=scheduler_gamma)

    def set_dataloader(self, dataloader):
        """Attach an object exposing ``train/test_support/query_dataloaders`` lists."""
        self.dataloader = dataloader

    def print_epoch_results(self, epoch, loss, is_train: bool):
        """Print ``loss = [support_loss, query_loss]`` for one meta-train/test epoch."""
        mode = "Meta Training" if is_train else "Meta Test"
        print(f"Epoch {epoch}: {mode} Support Loss: {loss[0]:.5f}  Query Loss {loss[1]:.5f}")

    def get_runtime(self):
        """Seconds elapsed since the current ``run`` started."""
        return timeit.default_timer() - self.time_start

    @staticmethod
    def _batch_to_device(batch):
        """Return ``(inputs, targets)`` of a loader batch as float32 tensors on ``device``."""
        inputs = batch['csi_data'].to(device=device, dtype=torch.float32)
        targets = batch['coordinates'].to(device=device, dtype=torch.float32)
        return inputs, targets

    def _adapt(self, support_loader, create_graph, show_predictions):
        """Adapt the model's parameters to one task with inner-loop SGD.

        Args:
            support_loader: DataLoader of the task's support set.
            create_graph: keep the graph of the inner gradients so the outer loss can
                be differentiated through the adaptation (second-order MAML).
            show_predictions: when printing progress, show the last batch's predictions
                (True) or its ground-truth coordinates (False).

        Returns:
            ``(adapt_params, support_loss)``. ``adapt_params`` is an OrderedDict keyed like
            ``named_parameters()``. ``support_loss`` is the sample-weighted mean support
            loss of the last inner epoch (0.0 when ``inner_epochs`` is 0).

        Raises:
            ValueError: if the support loader yields no samples.
        """
        adapt_params = collections.OrderedDict(self.model.named_parameters())
        support_loss = 0.0

        for inner_epoch in range(self.inner_epochs):
            weighted_loss_sum = 0.0
            n_samples = 0
            for batch in support_loader:
                inputs, targets = self._batch_to_device(batch)
                pred = self.model.functional_forward(inputs, params=adapt_params)
                inner_loss = self.loss_func(pred, targets)

                grads = torch.autograd.grad(inner_loss, adapt_params.values(),
                                            create_graph=create_graph)
                adapt_params = collections.OrderedDict(
                    (name, param - self.inner_step_size * grad)
                    for (name, param), grad in zip(adapt_params.items(), grads))

                weighted_loss_sum += inner_loss.item() * len(targets)
                n_samples += len(targets)

            if n_samples == 0:
                raise ValueError("Support set is empty; cannot adapt to this task")
            support_loss = weighted_loss_sum / n_samples

            if self.inner_loop_print_frequncy and inner_epoch % self.inner_loop_print_frequncy == 0:
                print("     ------------------------------")
                print(f"Epoch: {inner_epoch} Total Inner Loss : {support_loss}")
                print(pred[:2] if show_predictions else batch['coordinates'][:2])
                print("     ------------------------------")

        return adapt_params, support_loss

    def _query_loss(self, query_loader, adapt_params, keep_predictions=False):
        """Sample-weighted mean loss of ``adapt_params`` over a query loader.

        Returns:
            ``(mean_loss, pred_gtruth_sets)``.
            - ``mean_loss`` is a scalar tensor, differentiable w.r.t. the model
              parameters when grad mode is enabled.
            - ``pred_gtruth_sets`` holds one detached ``(prediction, target)`` pair per
              batch when ``keep_predictions`` is True, and is empty otherwise.

        Raises:
            ValueError: if the query loader yields no samples.
        """
        weighted_loss = torch.tensor(0., device=device)
        n_samples = 0
        pred_gtruth_sets = []
        for batch in query_loader:
            inputs, targets = self._batch_to_device(batch)
            pred = self.model.functional_forward(inputs, params=adapt_params)
            weighted_loss = weighted_loss + len(targets) * self.loss_func(pred, targets)
            n_samples += len(targets)
            if keep_predictions:
                pred_gtruth_sets.append((pred.detach(), targets))
        if n_samples == 0:
            raise ValueError("Query set is empty; cannot evaluate this task")
        return weighted_loss / n_samples, pred_gtruth_sets

    def train(self, task_index):
        """One meta-training step on task ``task_index``.

        Adapts on the task's support set and computes the query loss of the adapted
        parameters. Then applies one Adam step and one scheduler step to the initial
        parameters.

        Returns:
            ``(support_loss, query_loss)`` as Python floats.
        """
        self.model.train()

        adapt_params, support_loss = self._adapt(
            self.dataloader.train_support_dataloaders[task_index],
            create_graph=not self.first_order, show_predictions=True)
        query_loss, _ = self._query_loss(
            self.dataloader.train_query_dataloaders[task_index], adapt_params)

        self.meta_optimizer.zero_grad()
        query_loss.backward()
        self.meta_optimizer.step()
        self.scheduler.step()

        return support_loss, query_loss.item()

    def test(self, task_index):
        """Adapt to meta-test task ``task_index`` and evaluate it on its query set.

        The initial parameters are not updated. Nothing is back-propagated at test
        time, so:
        - inner gradients are taken without ``create_graph``;
        - the query pass runs under ``torch.no_grad()``.
        The adapted values are the same as with the full graph.

        Returns:
            ``(support_loss, query_loss, pred_gtruth_sets)``, where ``pred_gtruth_sets``
            is a list of detached ``(prediction, target)`` tensor pairs, one per batch.
        """
        self.model.train()

        adapt_params, support_loss = self._adapt(
            self.dataloader.test_support_dataloaders[task_index],
            create_graph=False, show_predictions=False)
        with torch.no_grad():
            query_loss, pred_gtruth_sets = self._query_loss(
                self.dataloader.test_query_dataloaders[task_index], adapt_params,
                keep_predictions=True)

        return support_loss, query_loss.item(), pred_gtruth_sets

    def run(self, iterations, train=False, test=False):
        """Run ``iterations`` meta-iterations of training and/or testing.

        Tasks are visited in a freshly shuffled order every iteration.

        Returns:
            ``(list_meta_train_loss, list_meta_test_loss, list_runtimes, pred_gtruth_sets_all_tasks)``.
            - The loss lists hold one mean ``[support, query]`` array per iteration.
            - Runtimes are cumulative seconds and are recorded only when training.
            - The prediction sets come from the final iteration's meta-test pass
              (empty if ``test`` is False or ``iterations`` is 0).
        """
        self.time_start = timeit.default_timer()

        list_meta_train_loss = []
        list_meta_test_loss = []
        list_runtimes = []
        pred_gtruth_sets_all_tasks = []

        for iteration in tqdm(range(iterations)):
            pred_gtruth_sets_all_tasks = []

            print(f"================={iteration}=================")
            if train:
                train_task_losses = []
                task_indices = list(range(len(self.dataloader.train_support_dataloaders)))
                random.shuffle(task_indices)
                for task_index in task_indices:
                    train_support_loss, train_query_loss = self.train(task_index)
                    train_task_losses.append([train_support_loss, train_query_loss])
                    print(f"      Epoch:{iteration} Training Task {task_index} Support Loss {train_support_loss} Query Loss {train_query_loss}")

                list_meta_train_loss.append(np.array(train_task_losses).mean(axis=0))
                self.print_epoch_results(iteration, list_meta_train_loss[-1], True)

            if test:
                test_task_losses = []
                task_indices = list(range(len(self.dataloader.test_support_dataloaders)))
                random.shuffle(task_indices)
                for task_index in task_indices:
                    test_support_loss, test_query_loss, pred_gtruth_sets = self.test(task_index)
                    test_task_losses.append([test_support_loss, test_query_loss])
                    pred_gtruth_sets_all_tasks.append(pred_gtruth_sets)
                    print(f"      Epoch:{iteration} Testing Task {task_index} Support Loss {test_support_loss} Query Loss {test_query_loss}")

                list_meta_test_loss.append(np.array(test_task_losses).mean(axis=0))
                self.print_epoch_results(iteration, list_meta_test_loss[-1], False)

            if train:
                list_runtimes.append(self.get_runtime())
                print(f"Runtime: {self.get_runtime():.2f}, \n")
            print("========================================")

        return list_meta_train_loss, list_meta_test_loss, list_runtimes, pred_gtruth_sets_all_tasks
        

In [48]:
class Adaptive_MAML():
    """MAML variant that gives every meta-training task its own outer learning rate.

    Inner-loop adaptation and query evaluation are the same as in plain MAML. Before
    each meta-update, the Adam learning rate is set to ``adaptive_step_sizes[task_index]``.
    Those rates are computed by ``calcualte_adaptive_step_size`` from
    ``read_importance_vector()``, around a centre step size that:
    - starts at ``outer_step_size``;
    - is multiplied by ``scheduler_gamma`` every ``scheduler_step_size`` iterations.

    NOTE(review): ``adaptive_step_sizes`` is indexed by the task's position in
    ``train_support_dataloaders``. This assumes the importance vector is ordered like the
    training scenarios — TODO confirm.

    Args:
        inner_step_size: SGD step size of the inner (adaptation) loop.
        inner_epochs: passes over the support set during adaptation.
        outer_step_size: initial centre step size of the meta-update.
        first_order: if True, inner gradients are not differentiated through.
        scheduler_gamma: decay factor of the centre step size.
        scheduler_step_size: iterations between decays. The first decay happens at
            iteration ``scheduler_step_size``; earlier iterations use ``outer_step_size``
            (previously the decay was applied already at iteration 0).
        inner_loop_print_frequncy: print inner-loop progress every this many inner
            epochs; 0 disables printing.
    """

    def __init__(self, inner_step_size: float, inner_epochs: int, outer_step_size: float,
                 first_order: bool, scheduler_gamma=0.9,
                 scheduler_step_size=30, inner_loop_print_frequncy=5):
        self.inner_step_size = inner_step_size
        self.inner_epochs = inner_epochs

        self.outer_step_size = outer_step_size
        self.iteration_step_size = self.outer_step_size
        self.scheduler_step_size = scheduler_step_size
        self.scheduler_gamma = scheduler_gamma

        self.first_order = first_order

        self.inner_loop_print_frequncy = inner_loop_print_frequncy

        self.time_start = 0

        self.model = CnnLocalization().to(device)
        # Same Parameter objects the optimizer updates in place, so this list stays current.
        self.model_params = list(self.model.parameters())

        self.loss_func = nn.MSELoss()  # eulicidain_distance_loss is an alternative

        self.meta_optimizer = torch.optim.Adam(self.model.parameters(), lr=outer_step_size)

        # Read once and reused every time the centre step size decays.
        self.importance_vector = read_importance_vector()
        self.adaptive_step_sizes = calcualte_adaptive_step_size(center_step_size=self.outer_step_size,
                                                                importance_vector=self.importance_vector)

    def set_dataloader(self, dataloader):
        """Attach an object exposing ``train/test_support/query_dataloaders`` lists."""
        self.dataloader = dataloader

    def print_epoch_results(self, epoch, loss, is_train: bool):
        """Print ``loss = [support_loss, query_loss]`` for one meta-train/test epoch."""
        mode = "Meta Training" if is_train else "Meta Test"
        print(f"Epoch {epoch}: {mode} Support Loss: {loss[0]:.5f}  Query Loss {loss[1]:.5f}")

    def get_runtime(self):
        """Seconds elapsed since the current ``run`` started."""
        return timeit.default_timer() - self.time_start

    @staticmethod
    def _batch_to_device(batch):
        """Return ``(inputs, targets)`` of a loader batch as float32 tensors on ``device``."""
        inputs = batch['csi_data'].to(device=device, dtype=torch.float32)
        targets = batch['coordinates'].to(device=device, dtype=torch.float32)
        return inputs, targets

    def _adapt(self, support_loader, create_graph, show_predictions):
        """Adapt the model's parameters to one task with inner-loop SGD.

        Returns:
            ``(adapt_params, support_loss)``. ``adapt_params`` is an OrderedDict keyed like
            ``named_parameters()``. ``support_loss`` is the sample-weighted mean support
            loss of the last inner epoch (0.0 when ``inner_epochs`` is 0).

        Raises:
            ValueError: if the support loader yields no samples.
        """
        adapt_params = collections.OrderedDict(self.model.named_parameters())
        support_loss = 0.0

        for inner_epoch in range(self.inner_epochs):
            weighted_loss_sum = 0.0
            n_samples = 0
            for batch in support_loader:
                inputs, targets = self._batch_to_device(batch)
                pred = self.model.functional_forward(inputs, params=adapt_params)
                inner_loss = self.loss_func(pred, targets)

                grads = torch.autograd.grad(inner_loss, adapt_params.values(),
                                            create_graph=create_graph)
                adapt_params = collections.OrderedDict(
                    (name, param - self.inner_step_size * grad)
                    for (name, param), grad in zip(adapt_params.items(), grads))

                weighted_loss_sum += inner_loss.item() * len(targets)
                n_samples += len(targets)

            if n_samples == 0:
                raise ValueError("Support set is empty; cannot adapt to this task")
            support_loss = weighted_loss_sum / n_samples

            if self.inner_loop_print_frequncy and inner_epoch % self.inner_loop_print_frequncy == 0:
                print("     ------------------------------")
                print(f"Epoch: {inner_epoch} Total Inner Loss : {support_loss}")
                print(pred[:2] if show_predictions else batch['coordinates'][:2])
                print("     ------------------------------")

        return adapt_params, support_loss

    def _query_loss(self, query_loader, adapt_params, keep_predictions=False):
        """Sample-weighted mean loss of ``adapt_params`` over a query loader.

        Returns:
            ``(mean_loss, pred_gtruth_sets)``, where the list holds one detached
            ``(prediction, target)`` pair per batch when ``keep_predictions`` is True.

        Raises:
            ValueError: if the query loader yields no samples.
        """
        weighted_loss = torch.tensor(0., device=device)
        n_samples = 0
        pred_gtruth_sets = []
        for batch in query_loader:
            inputs, targets = self._batch_to_device(batch)
            pred = self.model.functional_forward(inputs, params=adapt_params)
            weighted_loss = weighted_loss + len(targets) * self.loss_func(pred, targets)
            n_samples += len(targets)
            if keep_predictions:
                pred_gtruth_sets.append((pred.detach(), targets))
        if n_samples == 0:
            raise ValueError("Query set is empty; cannot evaluate this task")
        return weighted_loss / n_samples, pred_gtruth_sets

    def train(self, task_index):
        """One meta-training step on ``task_index`` with that task's own learning rate.

        Returns:
            ``(support_loss, query_loss)`` as Python floats.
        """
        self.model.train()

        adapt_params, support_loss = self._adapt(
            self.dataloader.train_support_dataloaders[task_index],
            create_graph=not self.first_order, show_predictions=True)
        query_loss, _ = self._query_loss(
            self.dataloader.train_query_dataloaders[task_index], adapt_params)

        for group in self.meta_optimizer.param_groups:
            group['lr'] = self.adaptive_step_sizes[task_index]

        self.meta_optimizer.zero_grad()
        query_loss.backward()
        self.meta_optimizer.step()

        return support_loss, query_loss.item()

    def test(self, task_index):
        """Adapt to meta-test task ``task_index`` and evaluate it on its query set.

        The initial parameters are not updated. Inner gradients are taken without
        ``create_graph``, and the query pass runs under ``torch.no_grad()``.

        Returns:
            ``(support_loss, query_loss, pred_gtruth_sets)``.
        """
        self.model.train()

        adapt_params, support_loss = self._adapt(
            self.dataloader.test_support_dataloaders[task_index],
            create_graph=False, show_predictions=False)
        with torch.no_grad():
            query_loss, pred_gtruth_sets = self._query_loss(
                self.dataloader.test_query_dataloaders[task_index], adapt_params,
                keep_predictions=True)

        return support_loss, query_loss.item(), pred_gtruth_sets

    def run(self, iterations, train=False, test=False):
        """Run ``iterations`` meta-iterations of training and/or testing.

        Every ``scheduler_step_size`` iterations after the first block, the centre step
        size is multiplied by ``scheduler_gamma`` and the per-task rates are recomputed.

        Returns:
            ``(list_meta_train_loss, list_meta_test_loss, list_runtimes, pred_gtruth_sets_all_tasks)``.
            The prediction sets come from the final iteration's meta-test pass.
        """
        self.time_start = timeit.default_timer()

        list_meta_train_loss = []
        list_meta_test_loss = []
        list_runtimes = []
        pred_gtruth_sets_all_tasks = []

        for iteration in tqdm(range(iterations)):
            pred_gtruth_sets_all_tasks = []

            if iteration % self.scheduler_step_size == 0:
                # Iteration 0 keeps outer_step_size (like StepLR); decay from then on.
                if iteration > 0:
                    self.iteration_step_size *= self.scheduler_gamma
                    self.adaptive_step_sizes = calcualte_adaptive_step_size(
                        center_step_size=self.iteration_step_size,
                        importance_vector=self.importance_vector)
                print(self.adaptive_step_sizes)

            print(f"================={iteration}=================")
            if train:
                train_task_losses = []
                task_indices = list(range(len(self.dataloader.train_support_dataloaders)))
                random.shuffle(task_indices)
                for task_index in task_indices:
                    train_support_loss, train_query_loss = self.train(task_index)
                    train_task_losses.append([train_support_loss, train_query_loss])
                    print(f"      Epoch:{iteration} Training Task {task_index} Support Loss {train_support_loss} Query Loss {train_query_loss}")

                list_meta_train_loss.append(np.array(train_task_losses).mean(axis=0))
                self.print_epoch_results(iteration, list_meta_train_loss[-1], True)

            if test:
                test_task_losses = []
                task_indices = list(range(len(self.dataloader.test_support_dataloaders)))
                random.shuffle(task_indices)
                for task_index in task_indices:
                    test_support_loss, test_query_loss, pred_gtruth_sets = self.test(task_index)
                    test_task_losses.append([test_support_loss, test_query_loss])
                    pred_gtruth_sets_all_tasks.append(pred_gtruth_sets)
                    print(f"      Epoch:{iteration} Testing Task {task_index} Support Loss {test_support_loss} Query Loss {test_query_loss}")

                list_meta_test_loss.append(np.array(test_task_losses).mean(axis=0))
                self.print_epoch_results(iteration, list_meta_test_loss[-1], False)

            if train:
                list_runtimes.append(self.get_runtime())
                print(f"Runtime: {self.get_runtime():.2f}, \n")
            print("========================================")

        return list_meta_train_loss, list_meta_test_loss, list_runtimes, pred_gtruth_sets_all_tasks

def read_importance_vector(path='results/conventional learning/5_shot_additional_learning.npy',
                           scale=None):
    """Load per-task importance scores used by Adaptive_MAML.

    Loads the array saved at ``path``, multiplies it by ``scale`` and averages along axis 1.

    Args:
        path: ``.npy`` file to load (default: the previously hard-coded location).
        scale: multiplier applied before averaging; ``None`` uses the notebook-level
            ``grid_width``, as before.

    NOTE(review): presumably the file holds one row of conventional-learning errors per
    task — TODO confirm; the row order must match the task indices used by Adaptive_MAML.
    """
    if scale is None:
        scale = grid_width
    return np.average(np.load(path) * scale, axis=1)

def calcualte_adaptive_step_size(center_step_size, importance_vector, offset_divisor=1000,
                                 min_step_size=0.0):
    """Map per-task importance scores to per-task meta learning rates.

    The transformation is:
    1. Min-max normalise the scores to [0, 1].
    2. Map them linearly to [1, -1], so a higher score gives a smaller step.
    3. Squash with tanh and divide by ``offset_divisor``.
    4. Add the result to ``center_step_size``.

    Args:
        center_step_size: step size given to a task with mid-range importance.
        importance_vector: 1-D array-like of per-task scores (lists are accepted).
        offset_divisor: the tanh offset is divided by this (default 1000, the value
            previously hard-coded); offsets lie within ``±tanh(1) / offset_divisor``.
        min_step_size: lower bound applied to every result. Without it, the fixed
            offset makes the step negative once ``center_step_size`` decays below
            ``tanh(1) / offset_divisor`` (~7.6e-4 by default), which would make Adam
            ascend the loss.

    Returns:
        List of floats, one step size per score. If all scores are equal (min-max
        undefined), every task gets ``center_step_size`` instead of NaN.
    """
    scores = np.asarray(importance_vector, dtype=float)
    spread = scores.max() - scores.min()
    if spread > 0:
        normalised = (scores - scores.min()) / spread
    else:
        # Equal importance everywhere: treat every task as mid-range (zero offset).
        normalised = np.full_like(scores, 0.5)

    signed = -(normalised - 0.5) * 2  # range (-1, 1); higher score -> smaller step
    step_sizes = center_step_size + np.tanh(signed) / offset_divisor
    return np.maximum(step_sizes, min_step_size).tolist()
    
# Preview of the per-task step sizes around a 0.002 centre step.
adaptive_step_sizes = calcualte_adaptive_step_size(
    center_step_size=0.002,
    importance_vector=read_importance_vector(),
)

In [39]:
def generate_test_tasks(n_training_tasks, n_experiments):
    """Draw ``n_experiments`` random lists of test scenario ids from scenarios 1..33.

    Each list has ``34 - n_training_tasks`` distinct ids, drawn with the global
    ``random`` state.

    NOTE(review): there are only 33 candidate ids, so ``n_training_tasks`` training
    tasks plus ``34 - n_training_tasks`` test tasks would need 34 scenarios. This may be
    an off-by-one; confirm the intended split before relying on it.
    """
    scenario_ids = list(range(1, 34))
    n_test_tasks = 34 - n_training_tasks
    return [random.sample(scenario_ids, k=n_test_tasks) for _ in range(n_experiments)]

# Preview: five random test-scenario draws for a 30-training-task split.
generate_test_tasks(30, 5)

[[31, 23, 14, 33],
 [9, 19, 7, 17],
 [10, 20, 7, 5],
 [22, 31, 7, 23],
 [28, 21, 14, 31]]

In [40]:
# Output root; run_experiment writes into its Losses/ and Errors/ sub-folders.
results_directory = "results/category_experiment/"

In [49]:
def run_experiment(model, seed, shots, outer_step_size,
                   inner_step_size, inner_epochs, scheduler_gamma, iterations,
                   batch_size, n_experiments, train_test_scenarios, experiment_name, scheduler_step_size=10):
    """Train and evaluate a meta-learner ``n_experiments`` times and save CSV results.

    Args:
        model: ``'MAML'``, ``'FOMAML'`` (first-order MAML) or ``'ADAMAML'`` (Adaptive_MAML).
        seed: passed to ``set_random_seed`` once, before the first repetition.
        shots: support packets per file for every task.
        outer_step_size, inner_step_size, inner_epochs, scheduler_gamma: meta-learner
            hyper-parameters.
        iterations: meta-iterations per repetition.
        batch_size: DataLoader batch size.
        n_experiments: number of repetitions, each with a freshly built model and loaders.
        train_test_scenarios: ``(train_scenarios, test_scenarios)``.
        experiment_name: tag embedded in the output file name.
        scheduler_step_size: decay period used by ``'ADAMAML'``. It is also part of the
            file name for every model.

    Writes (creating the folders if needed):
        - ``<results_directory>/Losses/<name>.csv``: per-iteration losses and runtimes,
          one column group per repetition.
        - ``<results_directory>/Errors/<name>.csv``: every final-iteration meta-test
          prediction with its ground truth and Euclidean error.

    Raises:
        ValueError: if ``model`` is not a supported name (checked before any work).
    """
    builders = {
        'MAML': lambda: MAML(inner_step_size=inner_step_size, inner_epochs=inner_epochs,
                             outer_step_size=outer_step_size, first_order=False,
                             scheduler_gamma=scheduler_gamma, inner_loop_print_frequncy=0),
        'FOMAML': lambda: MAML(inner_step_size=inner_step_size, inner_epochs=inner_epochs,
                               outer_step_size=outer_step_size, first_order=True,
                               scheduler_gamma=scheduler_gamma, inner_loop_print_frequncy=0),
        'ADAMAML': lambda: Adaptive_MAML(inner_step_size=inner_step_size, inner_epochs=inner_epochs,
                                         outer_step_size=outer_step_size, first_order=False,
                                         scheduler_gamma=scheduler_gamma,
                                         scheduler_step_size=scheduler_step_size,
                                         inner_loop_print_frequncy=0),
    }
    if model not in builders:
        raise ValueError(f"Wrong Model: {model!r}")

    set_random_seed(seed)

    train_scenarios, test_scenarios = train_test_scenarios

    df_results = pd.DataFrame()
    prediction_frames = []

    for i in range(n_experiments):
        dataloader = DataloderClass(train_scenarios, test_scenarios, batch_size, shots)

        meta_model = builders[model]()
        meta_model.set_dataloader(dataloader)
        train_losses, test_losses, runtimes, pred_truths = meta_model.run(
            iterations=iterations, train=True, test=True)

        train_losses = np.array(train_losses)
        test_losses = np.array(test_losses)
        df_results[f'test_{i}_train_support'] = train_losses[:, 0]
        df_results[f'test_{i}_train_query'] = train_losses[:, 1]
        df_results[f'test_{i}_test_support'] = test_losses[:, 0]
        df_results[f'test_{i}_test_query'] = test_losses[:, 1]
        df_results[f'test_{i}_runtime'] = runtimes

        # pred_truths: one list per test task, each holding (prediction, target) batches.
        for task_sets in pred_truths:
            for pred, truth in task_sets:
                prediction_frames.append(pd.DataFrame({
                    'truth_x': truth[:, 0].tolist(),
                    'truth_y': truth[:, 1].tolist(),
                    'pred_x': pred[:, 0].tolist(),
                    'pred_y': pred[:, 1].tolist(),
                }))

    # Concatenate once instead of growing the frame inside the loop.
    if prediction_frames:
        df_pred_truths = pd.concat(prediction_frames, ignore_index=True)
    else:
        df_pred_truths = pd.DataFrame(columns=['truth_x', 'truth_y', 'pred_x', 'pred_y'], dtype=float)
    df_pred_truths['error'] = np.sqrt((df_pred_truths['truth_x'] - df_pred_truths['pred_x']) ** 2
                                      + (df_pred_truths['truth_y'] - df_pred_truths['pred_y']) ** 2)

    file_name = (f'{model}_{shots}shot_{inner_step_size}stepinner_{outer_step_size}stepouter_'
                 f'{inner_epochs}epochinner_{batch_size}batch_{scheduler_gamma}schedulegamme_'
                 f'{experiment_name}_{scheduler_step_size}schedulestep{seed}seed.csv')
    losses_dir = os.path.join(results_directory, 'Losses')
    errors_dir = os.path.join(results_directory, 'Errors')
    # Create output folders so a long run is not lost to a missing directory at the end.
    os.makedirs(losses_dir, exist_ok=True)
    os.makedirs(errors_dir, exist_ok=True)
    df_results.to_csv(os.path.join(losses_dir, file_name), index=False)
    df_pred_truths.to_csv(os.path.join(errors_dir, file_name), index=False)

In [None]:
# run_experiment(model='FOMAML', seed=0, shots=5, inner_step_size=0.005,
#                outer_step_size=0.002, inner_epochs=10, scheduler_gamma=0.98,
#                batch_size=32, iterations=100, n_experiments=10, train_test_scenarios=(mixed_scenarios_train,mixed_scenarios_test), experiment_name="mixed")
