# NOTEBOOK FOR  TRAINING MANET MODEL ON ELECTRO-L №2 DATA

## IMPORT ALL REQUIRED PACKAGES

In [1]:
import torch
import torch.nn.functional as F
import torchmetrics
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import pandas as pd
import patoolib
from copy import deepcopy
from tqdm import tqdm
from timeit import default_timer as timer 
from PIL import Image
import numpy as np
import gc
import warnings
import os
import matplotlib.pyplot as plt
from torch import nn
from IPython.display import clear_output
from torcheval.metrics.functional import multiclass_f1_score
import random
import glob
import tifffile as tff
from skimage.transform import resize as interp_resize
import segmentation_models_pytorch as smp
import torch.optim as optim
import rasterio
from ranger21 import Ranger21
import kornia as K
from kornia.augmentation.container import AugmentationSequential
from kornia.augmentation import (
    RandomAffine,
    RandomElasticTransform,
    RandomHorizontalFlip,
    RandomPerspective,
    RandomRotation,
    RandomVerticalFlip)
%matplotlib inline
warnings.filterwarnings("ignore")

## Initilize your current directory and device (CUDA recommended)

In [2]:
np.set_printoptions(threshold=1e7)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
your_current_dir = os.getcwd()
print('Your current dir of this .ipynb file',your_current_dir)
print('Your device:',device)
your_current_dir = your_current_dir.replace('training_model_utils','')

Your current dir of this .ipynb file C:\Users\nikita.belyakov\Documents\GitHub\CLOUD_SNOW_SEGMENTATION\RES_4_KM\MANet_training
Your device: cuda


## Set random seeds for stability everywhere

In [3]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(seed=42)

## Define functions for preprocessing and postprocessing multispectral data as pytorch tensors

In [4]:
def batch_to_img(xb, idx=0):
    img = np.array(xb.squeeze(0))
    return np.array(img.transpose((1,2,0))*255)
def transpose_patch(patch):
    tr = np.array(patch).astype(np.uint8).transpose((2,0,1))
    return tr
def predb_to_mask(predb,idx=0):
    p = torch.functional.F.softmax(predb.squeeze(0), 0)
    return p.argmax(0).cpu()
def inverse_normalize(tensor,mean,std,num_shannels =5):
    tensor = tensor.detach().cpu().numpy()
    # unnormalize the RGB channels
    for i in range(num_shannels):
        tensor[i] = (tensor[i] * std[i]) + mean[i]
    # clip values to [0, 1] range
    tensor = np.clip(tensor, 0, 1)
    # convert back to uint8
    tensor = (tensor * 255).astype(np.uint8)
    return tensor
def open_sample_as_pil_no_norm(datacube):
    inversed_rgb = (datacube[0:3])
    rgb = np.dstack((inversed_rgb[2,:,:],inversed_rgb[1,:,:],inversed_rgb[0,:,:]))*255
    return rgb.astype('uint8')
def open_DEM(datacube):
    dem_arr = datacube[5]#*255
    return dem_arr.numpy()#.astype('uint8')
def open_mask_as_pil(seglabel):
    mask = (seglabel.numpy())
    return mask

## MODEL INITIALIZATION

In [5]:
# lets initialize  architecture of segmentation FOR ALL 12 channels
model = smp.MAnet(
    encoder_name='efficientnet-b0', 
    in_channels=12,
    classes=3,)

## CREATE A DATASET CLASS WITH ELECTRO-L №2 with masks from GOES, METEOSAT & Terra/MODIS for training**

In [8]:
class ELECTRO_L2_Dataset_(Dataset):
    def __init__(self, stack_dir_list,aux_dir, pytorch=True, include_BT = True,nonempty_mode = True):
        super().__init__()
        self.pytorch = pytorch   
        self.nonempty_mode = nonempty_mode
        self.include_BT = include_BT
        self.stack_dir_list = stack_dir_list
        self.non_empty_list = []
        self.snowy_list = []
        self.stack_dirs = []
        self.stack_dirs_BT = []
        self.all_snow_flags = []
        self.non_empty_snow_idxs = []
        self.non_empty_aux = []
        self.aux_dir = aux_dir
        total_len = 0
        for i in range(len(self.stack_dir_list)): # iterate over number of pictures in dataset (1 now)
            stack_dir = self.stack_dir_list[i]
            stack_dir_BT = self.stack_dir_list[i].replace('rgb', 'BT').replace('folder_ZSA','folder_BT')
            non_empty_idxs_dir = self.stack_dir_list[i].replace('all_patch_folder_ZSA','nonempty_idxs_folder')
            non_empty_idxs_dir = non_empty_idxs_dir.replace('patches_rgb','nonempty_idxs')
            snowy_idxs_dir = non_empty_idxs_dir.replace('nonempty_idxs_folder_ZSA', 'snowy_idxs_folder')
            snowy_idxs_dir = snowy_idxs_dir.replace('nonempty', 'snowy')
            snow_flag_array = np.load(snowy_idxs_dir+'.npy')
            self.all_snow_flags.append(snow_flag_array)
            if self.nonempty_mode:
                non_empty_idxs = np.load(non_empty_idxs_dir+'.npy')
                snow_idxs_nonempty = np.zeros(len(snow_flag_array))
                snow_idxs_nonempty[non_empty_idxs]=1
                snow_idxs_nonempty = (snow_idxs_nonempty*snow_flag_array).astype(np.uint8)
                for j in range(len(non_empty_idxs)):
                    idx_ = non_empty_idxs[j]
                    if snow_idxs_nonempty[idx_] == 1:
                        self.non_empty_snow_idxs.append(total_len+j)
                    patch_file = f"{stack_dir}/patch_{idx_}.tif"
                    self.stack_dirs.append(patch_file)
                    patch_file_BT = f"{stack_dir_BT}/patch_{idx_}.tif"
                    self.stack_dirs_BT.append(patch_file_BT)
                    self.non_empty_aux.append(idx_)
                total_len = total_len+len(non_empty_idxs)
            else:
                patch_files = glob.glob(stack_dir+'/*')
                patch_files_BT = glob.glob(stack_dir_BT+'/*')
                for patch_file in patch_files:
                    self.stack_dirs.append(patch_file)
                for patch_file_BT in patch_files_BT:
                    self.stack_dirs_BT.append(patch_file_BT) 
        self.all_snow_flags_nonempty =np.zeros(total_len).astype(np.uint8)
        self.all_snow_flags_nonempty[self.non_empty_snow_idxs]=1
        self.all_snow_flags = np.array(self.all_snow_flags).flatten()
        
    def __len__(self):
        num_patches = int((len(self.stack_dirs))) 
        return num_patches
    
    def open_rgb_normed(self, idx, invert=False):
        patch_file = self.stack_dirs[idx]
        patch = tff.imread(patch_file)
        r,g,b = patch[:,:,0]/255, patch[:,:,1]/255, patch[:,:,2]/255
        patch_normed  = np.dstack([b, g ,r])
        return patch_normed 
    
    def open_BT_normed(self, idx, invert=False):
        patch_file_BT = self.stack_dirs_BT[idx]
        patch_BT = tff.imread(patch_file_BT)
        return patch_BT 
    
    def open_aux(self, idx, invert=False):
        patch_file = self.aux_dir+'/patch_'+str(self.non_empty_aux[idx])+'.tif'
        patch = tff.imread(patch_file)
        lon,lat,dem = patch[:,:,0],patch[:,:,1],patch[:,:,2]
        lon,lat = (lon+180.0)/360.0,(lat+90.0)/180.0
        dem = dem/10000
        lonlatdem_normed = np.dstack([lon,lat,dem])
        return lonlatdem_normed
    
    def open_mask(self, idx):
        patch_file = self.stack_dirs[idx]
        mask_file = patch_file.replace('rgb','masks').replace('tif','jpg').replace('all_patch_folder_ZSA','all_masks_folder')
        mask = tff.imread(mask_file)
        snow_mask = mask[:,:,1]//255 #  1 - snow class
        cloud_mask = mask[:,:,0] # 2 - cloud class
        cloud_mask[cloud_mask!=0] = 2
        bg = mask[:,:,2]
        bin_mask = snow_mask + cloud_mask
        return bin_mask
    
    def open_as_pil(self, idx):
        patch_file = self.stack_dirs[idx]
        patch_image = tff.imread(patch_file)
        return patch_image 
    
    def __getitem__(self, idx):
        patch_file = self.stack_dirs[idx]
        mask_file = patch_file.replace('rgb','masks').replace('tif','jpg').replace('all_patch_folder_ZSA','all_masks_folder')
        mask = tff.imread(mask_file)
        snow_mask = mask[:,:,1]//255 #  1 - snow class
        cloud_mask = mask[:,:,0] # 2 - cloud class
        cloud_mask[cloud_mask!=0]=2
        bin_mask = snow_mask+cloud_mask
        patch = tff.imread(patch_file)
        r,g,b = patch[:,:,0]/255, patch[:,:,1]/255, patch[:,:,2]/255
        aux_file = self.aux_dir+'/patch_'+str(self.non_empty_aux[idx])+'.tif'
        aux = tff.imread(aux_file)
        lon,lat,dem = aux[:,:,0],aux[:,:,1],aux[:,:,2]
        lon,lat = (lon+180.0)/360.0,(lat+90.0)/180.0
        dem = dem/10000
        aux_normed = np.stack([lon, lat, dem])
        patch_normed  = np.stack([b, g ,r])
        if self.include_BT:
            patch_file_BT = self.stack_dirs_BT[idx] 
            BT_normed = tff.imread(patch_file_BT).transpose(2,0,1)
            full_stack = np.concatenate((patch_normed, BT_normed, aux_normed), axis=0)
            # order of channels: b, g, r, BT4, BT5, BT6, BT7, BT8, BT9, lon, lat, dem
            full_stack = torch.tensor(full_stack, dtype=torch.float32)
        else:
            full_stack = np.concatenate((patch_normed, aux_normed), axis=0)
            full_stack = torch.tensor(full_stack, dtype=torch.float32)
        return full_stack, torch.tensor(bin_mask).long()

## INITIALIZE ELECTRO-L №2 DATASET

In [18]:
data_current_dir = your_current_dir.replace('MANet_training','data_inference')
stack_dir = glob.glob(data_current_dir+'/all_patch_folder_ZSA/*')
aux_dir_l2 = data_current_dir + '/lon_lat_dem'
Electro_ds = ELECTRO_L2_Dataset_(stack_dir_list=stack_dir, aux_dir = aux_dir_l2,nonempty_mode = True, include_BT = True)

C:\Users\nikita.belyakov\Documents\GitHub\CLOUD_SNOW_SEGMENTATION\RES_4_KM\data_inference ['C:\\Users\\nikita.belyakov\\Documents\\GitHub\\CLOUD_SNOW_SEGMENTATION\\RES_4_KM\\data_inference/all_patch_folder_ZSA\\patches_rgb_electro_l2_20230115_1400']


## DEFINE A PIPELINE CLASS OF GEOMETRIC TRANSFORM ON GPU USING Kornia

In [19]:
import kornia as K
from kornia.augmentation.container import AugmentationSequential
class Geom_Augmentation(nn.Module):
    def __init__(self):
        super(Geom_Augmentation, self).__init__()
        # we define and cache our operators as class members
        self.augs = AugmentationSequential(
                    RandomVerticalFlip(p=1),
                    RandomHorizontalFlip(p=1),
                    RandomPerspective(0.25, sampling_method = 'area_preserving', p=1.),
                    RandomAffine(degrees =(-85.0,85.0),translate = None,scale = (0.9, 1.1),resample="nearest",shear = None,padding_mode="reflection",align_corners=True,same_on_batch=False,keepdim=True,p=1),
                    RandomElasticTransform(kernel_size=(33, 33), sigma=(6.0, 6.0), alpha=(1.0, 1.0), align_corners=True, resample='nearest', padding_mode='reflection', same_on_batch=False, p=1.0, keepdim=True),
                    data_keys=['input', 'mask'], same_on_batch = False, random_apply = 2)                       
    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # 2. apply geometric tranform
        out = self.augs(img, mask)
        img_out, mask_out = out[0],out[1]
        return img_out, mask_out
geom_augs = Geom_Augmentation()

## USE WEIGHTED SAMPLER TO MAKE DATALOADER MORE BALANCED WITH SNOW PATCHES

In [21]:
#SAMPLER FOR ELECTRO L2 DS FOR FINETUNNING
# Electro_ds_tr, Electro_ds_val  = torch.utils.data.random_split(Electro_ds, (0.8,0.2))
non_empty_snow_patches_ds = Electro_ds.all_snow_flags_nonempty
num_non_empty_snow_patches = len(non_empty_snow_patches_ds)
bg_weight = num_non_empty_snow_patches/len(Electro_ds)
snowy_weight = 1 - num_non_empty_snow_patches/len(Electro_ds)
batch_size = 16
_, counts = torch.unique(torch.tensor(non_empty_snow_patches_ds), return_counts=True)
weights = counts.max() / counts
print("Weights: ", weights)
weight_for_sampler_l2 = []  # Every sample must have a weight
for snow_flag in non_empty_snow_patches_ds:
    weight_for_sampler_l2.append(weights[snow_flag].item())
sampler_l2 = WeightedRandomSampler(torch.tensor(weight_for_sampler_l2), len(Electro_ds))

Weights:  tensor([ 1.0000, 22.4000])


In [27]:
batch_size = 16 #16 # try to set as max as possible
train_dl_ = DataLoader(Electro_ds, batch_size=batch_size, sampler = sampler_l2)
len(Electro_ds)

117

## Define function for train 1 epoch and saving best model according IoU metric value on validation subset

In [28]:
def save_best_model(model,epoch,path = your_current_dir.replace('MANet_training','models/MANet_ep_')):
        model_copy = deepcopy(model)
        best_model = model_copy
        best_model_name = path+str(epoch)
        torch.save(best_model.state_dict(),best_model_name)
        print('best model is on epoch =',epoch)
        return best_model
def train_ep(model, train_dataload, dice, focal, focal_alpha, optimizer,ep_i, best_train_iou, scheduler = None):    
    model.cuda()
    print('epoch_n =',ep_i)
    model.train(True)  # Set train mode = true
    step = 0
    train_loss = 0
    #initialize metrics
    train_f1_score = torchmetrics.F1Score(num_classes=3, task = 'multiclass', average = 'macro').cuda()
    train_iou_score = torchmetrics.JaccardIndex(num_classes=3, task = 'multiclass').cuda()
    train_f1_score_sep = torchmetrics.F1Score(num_classes=3, task = 'multiclass', average = None).cuda()
    train_iou_score_sep = torchmetrics.JaccardIndex(num_classes=3, task = 'multiclass', average = None).cuda()
    # COMMENTED CODE IS FOR ADDING VALIDATION STEP IN A TRAINING LOOP
#     val_f1_score = torchmetrics.F1Score(num_classes=3, task = 'multiclass', average = 'macro').cuda()
#     val_iou_score = torchmetrics.JaccardIndex(num_classes=3, task = 'multiclass').cuda()
#     val_f1_score_sep = torchmetrics.F1Score(num_classes=3, task = 'multiclass', average = None).cuda()
#     val_iou_score_sep = torchmetrics.JaccardIndex(num_classes=3, task = 'multiclass', average = None).cuda()
#     val_acc = torchmetrics.Accuracy(num_classes=3, task = 'multiclass').cuda()
    # iterate over data
    print('-----------training process---------')
    for x,y in tqdm(train_dataload): 
        x = torch.tensor(x).type(torch.float32).cuda()
        y = torch.tensor(y).type(torch.float32).cuda()
        #add geom augmentations on GPU from kornia
        x, y = geom_augs(x, y.unsqueeze(1)) #convert labels to float32 for kornia !!!!
        y = y.type(torch.LongTensor).squeeze(1).cuda() #convert labels to Long again for model !!!!
        step += 1
        # vector graph of training with grad on CUDA
        optimizer.zero_grad()
        output = model(x)
        output = torch.functional.F.softmax(output, 1)
        predictions = output.argmax(dim=1).cuda()
        loss = (1-focal_alpha)*dice(output, y)+focal_alpha*focal(output, y)
        loss.backward()
        optimizer.step()
        # other things can be done on CPU
        loss = loss.detach()
        predictions, y = predictions.detach(), y.detach()
        train_f1_score.update(predictions, y)
        train_iou_score.update(predictions, y)
        train_f1_score_sep.update(predictions, y)
        train_iou_score_sep.update(predictions, y)
        train_loss=train_loss+loss
        x,y = None, None
        # need for torch.no_grad in this training pass
        if scheduler!=None:
            scheduler.step()
    train_loss = train_loss.cpu()/len(train_dataload)
    print('after training epoch mean train loss =',train_loss)
    # Compute the train metrics for the epoch
    train_f1 = train_f1_score.compute()
    train_iou = train_iou_score.compute()
    train_f1_sep = train_f1_score_sep.compute()
    train_iou_sep = train_iou_score_sep.compute()

    # Reset the train metrics objects for the next epoch
    train_f1_score.reset()
    train_iou_score.reset()
    train_f1_score_sep.reset()
    train_iou_score_sep.reset()
# # Inside your epoch training loop, after each batch is processed, compute the metrics on the batch predictions and ground truth
#     with torch.no_grad():
#         val_loss = 0
#         print('-----------validation process---------')
#         for x,y in tqdm(valid_dataload):
#             x = x.cuda()#.cuda()
#             y = y.type(torch.LongTensor).cuda()
#             output = model(x)
#             output = torch.functional.F.softmax(output, 1)
#             # Assuming output has shape (batch_size, num_classes, height, width)
#             # Convert the output to predictions by taking the argmax along the channel dimension
#             predictions = output.argmax(dim=1).cuda()
#             valid_loss = (1-focal_alpha)*dice(output, y)+focal_alpha*focal(output, y)
#             val_loss = val_loss+valid_loss
#             val_f1_score.update(predictions, y)
#             val_iou_score.update(predictions, y)
#             val_f1_score_sep.update(predictions, y)
#             val_iou_score_sep.update(predictions, y)
#             val_acc.update(predictions, y)
#             x,y = None, None
#         val_f1 = val_f1_score.compute()
#         val_iou = val_iou_score.compute()
#         val_acc_ = val_acc.compute()
#         #calculate the same metrics seperately for each class
#         val_f1_sep = val_f1_score_sep.compute()
#         val_iou_sep = val_iou_score_sep.compute()
#         FAR = 1 - val_acc_
            
#         # Compute the validation metrics for the epoch
#         val_loss = val_loss.cpu()/(len(valid_dataload))
#         print('valid loss =',val_loss)

#         # Reset the validation metrics objects for the next epoch
#         val_f1_score.reset()
#         val_iou_score.reset()
#         val_f1_score_sep.reset()
#         val_iou_score_sep.reset()
#         val_acc.reset()
        # Print the F1 score and IoU for the current epoch on the train and validation sets
    print('######### METRICS AFTER TRAINING EPOCH #########')
    print(f"Epoch {ep_i}, Train F1 score: {train_f1:.4f}, Train IoU: {train_iou:.4f}")
    print(f"Train F1 score for each class: {train_f1_sep.cpu().numpy()}, \nTrain IoU for each class: {train_iou_sep.cpu().numpy()}")
#         print(f"Test FAR: {FAR:.4f}")
#         print(f"Test F1 score for each class: {val_f1_sep.cpu().numpy()}, \nTest IoU for each class: {val_iou_sep.cpu().numpy()}")
    if (best_train_iou<train_iou) or (train_iou>0.7):
        save_best_model(model,ep_i)
        print(" model updated")
    return train_f1.cpu(),train_iou.cpu(), train_loss.cpu()

## Train loop MANet ELECTRO-L №2 data with Ranger21 optimizer**

In [30]:
# continue finetunning training process on Electro-L № 2 ds with Ranger21 optimizer
from ranger21 import Ranger21
max_ep_num = 2 # can be set up more 
best_train_iou_ = 0.5
cur_ep = 0
lr, weight_decay = 1e-3, 1e-4
#optimizer = optim.AdamW(model.parameters(),lr = 1e-3)
#scheduler = CosineAnnealingLR(optimizer, T_max=max_ep_num)
optimizer_Ranger21 =  Ranger21(model.parameters(), lr = lr, weight_decay = weight_decay,
                                num_epochs = max_ep_num,num_batches_per_epoch = len(train_dl_))
model.train()
dice = smp.losses.DiceLoss(mode= 'multiclass')
focal = smp.losses.FocalLoss(mode= 'multiclass', gamma = 2)
focal_alpha = 0.7
tr_f1_,tr_iou_,tr_loss_ = [],[best_train_iou_],[]
for ep_i in range(cur_ep,cur_ep+max_ep_num):
    train_f1_,train_iou_,train_loss_= train_ep(model, train_dl_, dice, focal,focal_alpha,optimizer_Ranger21,ep_i,best_train_iou_,scheduler=None)
    tr_f1_.append(train_f1_)
    tr_iou_.append(train_iou_)
    best_train_iou_ = max(tr_iou_)
    tr_loss_.append(train_loss_)

Ranger21 optimizer ready with following settings:

Core optimizer = AdamW
Learning rate of 0.001

Important - num_epochs of training = ** 2 epochs **
please confirm this is correct or warmup and warmdown will be off

Warm-up: linear warmup, over 3 iterations

Lookahead active, merging every 5 steps, with blend factor of 0.5
Norm Loss active, factor = 0.0001
Stable weight decay of 0.0001
Gradient Centralization = On

Adaptive Gradient Clipping = True
	clipping value of 0.01
	steps for clipping = 0.001

Warm-down: Linear warmdown, starting at 72.0%, iteration 11 of 16
warm down will decay until 3e-05 lr
epoch_n = 0
-----------training process---------


 12%|██████████▌                                                                         | 1/8 [00:01<00:13,  1.92s/it]

params size saved
total param groups = 1
total params in groups = 307


 50%|██████████████████████████████████████████                                          | 4/8 [00:07<00:07,  1.92s/it]


** Ranger21 update = Warmup complete - lr set to 0.001



100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00,  1.70s/it]


after training epoch mean train loss = tensor(0.6279)
######### METRICS AFTER TRAINING EPOCH #########
Epoch 0, Train F1 score: 0.4646, Train IoU: 0.3458
Train F1 score for each class: [0.7271696  0.06090691 0.60586756], 
Train IoU for each class: [0.5713012  0.03141    0.43458396]
epoch_n = 1
-----------training process---------


 38%|███████████████████████████████▌                                                    | 3/8 [00:05<00:09,  1.88s/it]


** Ranger21 update: Warmdown starting now.  Current iteration = 11....



100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:13<00:00,  1.68s/it]

error in warmdown - lr below min lr. current lr = 2.999999999999997e-05
auto handling but please report issue!
after training epoch mean train loss = tensor(0.6142)
######### METRICS AFTER TRAINING EPOCH #########
Epoch 1, Train F1 score: 0.4887, Train IoU: 0.3784
Train F1 score for each class: [0.76564634 0.03796589 0.66263616], 
Train IoU for each class: [0.62028116 0.01935027 0.49547938]





## BLENDING SEVERAL MODELS via MODELS.SOUP TO RAISE QUALITY

In [290]:
# PREPARE A LIST OF SEVERAL MODELS WITH THE SAME ARCHITECTURE TO SOUP THEIR WEIGHTS
model_path1 = '1st_model_dir'
model_path2 = '2nd_model_dir'
model_path3 = '3rd_model_dir'
model_path_list = [model_path1,model_path2,model_path3]
for i, model_path in enumerate(model_path_list):
    print(model_path)

H:\ELECTRO_DATASET\4_km_res\L2/models/MAnet_Efficient_b0_12_inputs_4km_res_ep_216
H:\ELECTRO_DATASET\4_km_res\L2/models/MAnet_Efficient_b0_12_inputs_4km_res_ep_198
H:\ELECTRO_DATASET\4_km_res\L2/models/MAnet_Efficient_b0_12_inputs_4km_res_ep_175


In [291]:
def uniform_soup(model, path, device = "cpu", by_name = False):
    try:
        import torch
    except:
        print("If you want to use 'Model Soup for Torch', please install 'torch'")
        return model
        
    if not isinstance(path, list):
        path = [path]
    model = model.to(device)
    model_dict = model.state_dict()
    soups = {key:[] for key in model_dict}
    for i, model_path in enumerate(path):
        weight = torch.load(model_path, map_location = device)
        weight_dict = weight.state_dict() if hasattr(weight, "state_dict") else weight
        if by_name:
            weight_dict = {k:v for k, v in weight_dict.items() if k in model_dict}
        for k, v in weight_dict.items():
            soups[k].append(v)
    if 0 < len(soups):
        soups = {k:(torch.sum(torch.stack(v), axis = 0) / len(v)).type(v[0].dtype) for k, v in soups.items() if len(v) != 0}
        model_dict.update(soups)
        model.load_state_dict(model_dict)
    return model
print("\n[Uniform Soup Performance]")
souped_model = uniform_soup(model, model_path_list, device = device)


[Uniform Soup Performance]
-----------Testing process---------


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [01:20<00:00,  1.24it/s]

Test F1 score: 0.8077, Test IoU: 0.7432
Test FAR: 0.1093
Test F1 score for each class(bg, snow, cloud): [0.8006987  0.74544346 0.8768815 ], 
Test IoU for each class(bg, snow, cloud): [0.6924185 0.7397989 0.7972732]





In [292]:
def save_souped_model(model,num_models=3, path =  your_current_dir.replace('MANet_training','models/MANet_souped')):
        model_copy = deepcopy(model)
        best_model = model_copy
        best_model_name = path+'_3_models'
        torch.save(best_model.state_dict(),best_model_name)
        print('souped model saved!')
        return best_model
save_souped_model(souped_model)

souped model saved!


MAnet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      12, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePaddi