# Extending the model with learned weights

In [1]:
import torch
from torchsummary import summary
import numpy as np
import pandas as pd
import os

from unet import UNet  # Convolutional Neural Network model
from functions import chart_cbar, r2_metric, f1_metric, compute_metrics  # Functions to calculate metrics and show the relevant chart colorbar.
from utils import CHARTS, SIC_LOOKUP, SOD_LOOKUP, FLOE_LOOKUP, SCENE_VARIABLES, colour_str

%store -r train_options

In [2]:
# Central configuration dict for this notebook. It is handed to the networks via `options=`
# (see the `UNet(options=train_options)` / `UNetTrans(options=train_options)` calls below).
# The two paths are read from environment variables; `os.environ[...]` raises a KeyError
# if AI4ARCTIC_DATA or AI4ARCTIC_ENV is not set.
train_options = {
    # -- Training options -- #
    'path_to_processed_data': os.environ['AI4ARCTIC_DATA'],  # Replace with data directory path.
    'path_to_env': os.environ['AI4ARCTIC_ENV'],  # Replace with environment directory path.
    'lr': 0.0001,  # Optimizer learning rate.
    'epochs': 50,  # Number of epochs before training stop.
    'epoch_len': 500,  # Number of batches for each epoch.
    'patch_size': 256,  # Size of patches sampled. Used for both Width and Height.
    'batch_size': 8,  # Number of patches for each batch.
    'loader_upsampling': 'nearest',  # How to upscale low resolution variables to high resolution.
    
    # -- Data preparation lookups and metrics -- #
    'train_variables': SCENE_VARIABLES,  # Contains the relevant variables in the scenes.
    'charts': CHARTS,  # Charts to train on.
    'n_classes': {  # Number of total classes in the reference charts, including the mask.
        'SIC': SIC_LOOKUP['n_classes'],
        'SOD': SOD_LOOKUP['n_classes'],
        'FLOE': FLOE_LOOKUP['n_classes']
    },
    'pixel_spacing': 80,  # SAR pixel spacing. 80 for the ready-to-train AI4Arctic Challenge dataset.
    'train_fill_value': 0,  # Mask value for SAR training data.
    'class_fill_values': {  # Mask value for class/reference data.
        'SIC': SIC_LOOKUP['mask'],
        'SOD': SOD_LOOKUP['mask'],
        'FLOE': FLOE_LOOKUP['mask'],
    },
    
    # -- Validation options -- #
    'chart_metric': {  # Metric functions for each ice parameter and the associated weight.
        'SIC': {
            'func': r2_metric,
            'weight': 2,
        },
        'SOD': {
            'func': f1_metric,
            'weight': 2,
        },
        'FLOE': {
            'func': f1_metric,
            'weight': 1,
        },
    },
    'num_val_scenes': 10,  # Number of scenes randomly sampled from train_list to use in validation.
    
    # -- GPU/cuda options -- #
    'gpu_id': 0,  # Index of GPU. In case of multiple GPUs.
    'num_workers': 6,  # Number of parallel processes to fetch data.
    'num_workers_val': 1,  # Number of parallel processes during validation.
    
    # -- U-Net Options -- #
    'unet_conv_filters': [16, 32, 32, 32],  # Number of filters in the U-Net.
    'conv_kernel_size': (3, 3),  # Size of convolutional kernels.
    'conv_stride_rate': (1, 1),  # Stride rate of convolutional kernels.
    'conv_dilation_rate': (1, 1),  # Dilation rate of convolutional kernels.
    'conv_padding': (1, 1),  # Number of padded pixels in convolutional layers.
    'conv_padding_style': 'zeros',  # Style of padding.

    # -- Transfer learning options -- #
    'transfer_learning': False,  # Whether to use transfer learning.
    'transfer_model_architecture': {'unet_conv_filters': [16, 32, 32, 32],}, # Dict of the U-Net options in which the transferred model's architecture differs.
    'transfer_model_path': 'archive/model-first_run_4lvl',  # Path to the model to transfer from.
}

In [3]:
# Pick the compute device: the CUDA GPU selected by train_options['gpu_id'] when CUDA is
# usable, the CPU otherwise. The choice is reported on stdout and stored in `device`.
if not torch.cuda.is_available():
    print(colour_str('GPU not available.', 'red'))
    device = torch.device('cpu')

else:
    print(colour_str('GPU available!', 'green'))
    print('Total number of available devices: ', colour_str(torch.cuda.device_count(), 'orange'))
    # torch.device(type, index) is the two-argument spelling of "cuda:<index>".
    device = torch.device('cuda', train_options['gpu_id'])

print('GPU setup complete.')

[0;31mGPU not available.[0m
GPU setup complete.


  return torch._C._cuda_getDeviceCount() > 0


We load the checkpoint that `quickstart.ipynb` saved with the following call:

```python
torch.save(obj={'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch},
           f=f"{train_options['model_codename']}_{epoch}_best_model")
```

In [4]:
# Display the configured environment directory (read from the AI4ARCTIC_ENV environment variable).
train_options['path_to_env']

'/home/work/Dokumente/Studium/SimTech_MSc/Erasmus/Lectures/AI_in_industry/project/AI4ArcticSeaIceChallenge/'

In [5]:
# Load the checkpoint of the previously trained network. Despite its name, `model` is the
# checkpoint dict (keys 'model_state_dict', 'optimizer_state_dict', 'epoch'), not an nn.Module.
# - The path comes from the options dict instead of being repeated here as a literal.
# - `map_location=device` remaps the stored tensors onto the device selected above, so a
#   checkpoint that was written on a GPU can also be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model = torch.load(train_options['transfer_model_path'], map_location=device)
print(model['model_state_dict'].keys())
print(model['epoch'])
print(model['optimizer_state_dict'].keys())

odict_keys(['input_block.double_conv.0.weight', 'input_block.double_conv.1.weight', 'input_block.double_conv.1.bias', 'input_block.double_conv.1.running_mean', 'input_block.double_conv.1.running_var', 'input_block.double_conv.1.num_batches_tracked', 'input_block.double_conv.3.weight', 'input_block.double_conv.4.weight', 'input_block.double_conv.4.bias', 'input_block.double_conv.4.running_mean', 'input_block.double_conv.4.running_var', 'input_block.double_conv.4.num_batches_tracked', 'contract_blocks.0.double_conv.double_conv.0.weight', 'contract_blocks.0.double_conv.double_conv.1.weight', 'contract_blocks.0.double_conv.double_conv.1.bias', 'contract_blocks.0.double_conv.double_conv.1.running_mean', 'contract_blocks.0.double_conv.double_conv.1.running_var', 'contract_blocks.0.double_conv.double_conv.1.num_batches_tracked', 'contract_blocks.0.double_conv.double_conv.3.weight', 'contract_blocks.0.double_conv.double_conv.4.weight', 'contract_blocks.0.double_conv.double_conv.4.bias', 'contr

This is our network with four convolutional filter levels (`[16, 32, 32, 32]`), which we can see by counting the four contract_blocks and the four expand_blocks.

In [6]:
# Compare the checkpoint with a freshly built U-Net: for every multi-dimensional entry of the
# saved state dict, print its shape next to the shape of the same-named entry of the new
# network and flag any mismatch. Nothing is copied into the network in this cell.

old_weights = list(model['model_state_dict'])  # names of all entries stored in the checkpoint

net = UNet(options=train_options)
weights_new = net.state_dict()

for name, tensor in model['model_state_dict'].items():
    # Entries with at most one dimension are not compared.
    if tensor.dim() <= 1:
        continue
    print(f"weights: {name} and {tensor.shape}")
    target_shape = weights_new[name].shape
    print(f"unet: {name} and {target_shape}")
    if tensor.shape != target_shape:
        print('Shape mismatch!')
        print(f"SHAPES: {tensor.shape} and {target_shape}")


weights: input_block.double_conv.0.weight and torch.Size([16, 24, 3, 3])
unet: input_block.double_conv.0.weight and torch.Size([16, 24, 3, 3])
weights: input_block.double_conv.3.weight and torch.Size([16, 16, 3, 3])
unet: input_block.double_conv.3.weight and torch.Size([16, 16, 3, 3])
weights: contract_blocks.0.double_conv.double_conv.0.weight and torch.Size([32, 16, 3, 3])
unet: contract_blocks.0.double_conv.double_conv.0.weight and torch.Size([32, 16, 3, 3])
weights: contract_blocks.0.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
unet: contract_blocks.0.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
weights: contract_blocks.1.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
unet: contract_blocks.1.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
weights: contract_blocks.1.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
unet: contract_blocks.1.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3]

In [7]:
# Copy the checkpoint weights into the freshly built U-Net.
net.load_state_dict(model['model_state_dict'])

# Number of input channels, read from the first convolution of the checkpoint
# (weight layout is [out_channels, in_channels, kH, kW]) instead of a hard-coded 24.
n_input_channels = model['model_state_dict']['input_block.double_conv.0.weight'].shape[1]

# torchsummary defaults to device="cuda" and then creates its dummy input on the GPU whenever
# CUDA is available. `net` is never moved to a device in this notebook, so tell torchsummary
# where the parameters actually live; otherwise the forward pass fails on a GPU machine.
net_device = next(net.parameters()).device.type
summary(
    net,
    input_size=(n_input_channels, train_options['patch_size'], train_options['patch_size']),
    device=net_device,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 256, 256]           3,456
       BatchNorm2d-2         [-1, 16, 256, 256]              32
              ReLU-3         [-1, 16, 256, 256]               0
            Conv2d-4         [-1, 16, 256, 256]           2,304
       BatchNorm2d-5         [-1, 16, 256, 256]              32
              ReLU-6         [-1, 16, 256, 256]               0
        DoubleConv-7         [-1, 16, 256, 256]               0
         MaxPool2d-8         [-1, 16, 128, 128]               0
            Conv2d-9         [-1, 32, 128, 128]           4,608
      BatchNorm2d-10         [-1, 32, 128, 128]              64
             ReLU-11         [-1, 32, 128, 128]               0
           Conv2d-12         [-1, 32, 128, 128]           9,216
      BatchNorm2d-13         [-1, 32, 128, 128]              64
             ReLU-14         [-1, 32, 1

In [8]:
# U-Net variant from the project's `unet_transfer` module.
from unet_transfer import UNetTrans

# Built from the same options dict as `net`.
# NOTE(review): presumably UNetTrans reads the 'transfer_*' entries of train_options to pull
# in the stored weights - confirm in unet_transfer.py.
net_trans = UNetTrans(options=train_options)
net_trans  # Bare last expression: the notebook displays the module's layer structure.

UNetTrans(
  (input_block): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(24, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (contract_blocks): ModuleList(
    (0): ContractingBlock(
      (contract_block): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (double_conv): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [9]:
# Sanity check for a transfer model of identical size: every multi-dimensional entry of
# `net`'s state dict is compared, first by shape and then by value, with the entry of the
# same name in `net_trans`.
weights_old = net.state_dict()

weights_new = net_trans.state_dict()

for name, tensor in weights_old.items():
    # Entries with at most one dimension are not compared.
    if tensor.dim() <= 1:
        continue
    print(f"{name} and {tensor.shape}")
    counterpart = weights_new[name]
    print(f"{name} and {counterpart.shape}")
    if tensor.shape != counterpart.shape:
        print('Shape mismatch!')
        print(f"SHAPES: {tensor.shape} and {counterpart.shape}")
    elif not torch.allclose(tensor, counterpart):
        print('Weights mismatch!')
        # Only the first kernel slice is printed to keep the output short.
        print(f"WEIGHTS old: {tensor[0,0,:,:]} \nWEIGHTS new: {counterpart[0,0,:,:]}")

input_block.double_conv.0.weight and torch.Size([16, 24, 3, 3])
input_block.double_conv.0.weight and torch.Size([16, 24, 3, 3])
input_block.double_conv.3.weight and torch.Size([16, 16, 3, 3])
input_block.double_conv.3.weight and torch.Size([16, 16, 3, 3])
contract_blocks.0.double_conv.double_conv.0.weight and torch.Size([32, 16, 3, 3])
contract_blocks.0.double_conv.double_conv.0.weight and torch.Size([32, 16, 3, 3])
contract_blocks.0.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
contract_blocks.0.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
contract_blocks.1.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
contract_blocks.1.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
contract_blocks.1.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
contract_blocks.1.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
contract_blocks.2.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
contra

In [20]:
# Check that the weights of the smaller (4-level) network arrive unchanged in a larger
# (5-level) UNetTrans. Requires `weights_old` from the same-size check above.
# NOTE: this edits `train_options` in place; cells that build a network from it afterwards
# will see the 5-level configuration.
train_options['unet_conv_filters'] = [16, 32, 32, 32, 32]
train_options['transfer_model_architecture']['unet_conv_filters'] = [16, 32, 32, 32]
net_trans = UNetTrans(options=train_options)

weights_new = net_trans.state_dict()

# The expand blocks of the larger net are shifted relative to the smaller one by the number of
# additional levels. Derived from the two filter lists; equals 1 for this configuration.
depth_shift = (len(train_options['unet_conv_filters'])
               - len(train_options['transfer_model_architecture']['unet_conv_filters']))

for name, tensor in weights_old.items():
    if len(tensor.shape) > 1:
        # Map the old parameter name onto the larger net: only the block index of
        # 'expand_blocks.<index>. ...' changes, all other name components are kept as they are
        # (independent of how many components the name has).
        name_parts = name.split('.')
        if name_parts[0] == 'expand_blocks':
            name_parts[1] = str(int(name_parts[1]) + depth_shift)
        new_name = '.'.join(name_parts)

        print(f"OLD: {name} and {tensor.shape}")
        print(f"NEW: {new_name} and {weights_new[new_name].shape}")
        if tensor.shape != weights_new[new_name].shape:
            print('Shape mismatch!')
            print(f"SHAPES: {tensor.shape} and {weights_new[new_name].shape}")
        elif not torch.allclose(tensor, weights_new[new_name]):
            print('Weights mismatch!')
            # Only the first kernel slice is printed to keep the output short.
            print(f"WEIGHTS old: {tensor[0,0,:,:]} \nWEIGHTS new: {weights_new[new_name][0,0,:,:]}")


OLD: input_block.double_conv.0.weight and torch.Size([16, 24, 3, 3])
NEW: input_block.double_conv.0.weight and torch.Size([16, 24, 3, 3])
OLD: input_block.double_conv.3.weight and torch.Size([16, 16, 3, 3])
NEW: input_block.double_conv.3.weight and torch.Size([16, 16, 3, 3])
OLD: contract_blocks.0.double_conv.double_conv.0.weight and torch.Size([32, 16, 3, 3])
NEW: contract_blocks.0.double_conv.double_conv.0.weight and torch.Size([32, 16, 3, 3])
OLD: contract_blocks.0.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
NEW: contract_blocks.0.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
OLD: contract_blocks.1.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
NEW: contract_blocks.1.double_conv.double_conv.0.weight and torch.Size([32, 32, 3, 3])
OLD: contract_blocks.1.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
NEW: contract_blocks.1.double_conv.double_conv.3.weight and torch.Size([32, 32, 3, 3])
OLD: contract_blocks.2.doubl