In [1]:
import argparse
import numbers
import os

from comet_ml import Experiment
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms

import models
from helper import *
# Make CUDA device 1 the current device so the argument-less .cuda() calls below place models on it.
torch.cuda.set_device(1)

In [2]:
# Seed torch's CPU RNG and the current CUDA device's RNG so later weight
# initialisation and DataLoader shuffling repeat from run to run.
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# The 12 CamVid label names (the models below are built with classes = 12).
classes = ['sky', 'building', 'pole', 'road', 'pavement', 'tree', 'signsymbol', 'fence', 'car', 'pedestrian', 'bicyclist', 'unlabelled']
DATA_DIR = '../data/CamVid/'

# Inputs (x) live in <split>/ and their annotations (y) in <split>annot/.
x_train_dir, y_train_dir = (os.path.join(DATA_DIR, sub) for sub in ('train', 'trainannot'))
x_valid_dir, y_valid_dir = (os.path.join(DATA_DIR, sub) for sub in ('val', 'valannot'))

# Convert to a tensor, then normalise per channel with fixed statistics
# (presumably measured on the CamVid training images — confirm the source).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.41189489566336, 0.4251328133025, 0.4326707089857],
        std=[0.27413549931506, 0.28506257482912, 0.28284674400252],
    ),
])

# Build the train split first, then the validation split, with the same transform.
train_dataset, valid_dataset = (
    CamVid(images, annots, classes=classes, transform=transform)
    for images, annots in ((x_train_dir, y_train_dir), (x_valid_dir, y_valid_dir))
)

# Training: shuffled batches of 8, dropping a final partial batch.
# Validation: one sample at a time in a fixed order.
trainloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
valloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

In [7]:
# Student: U-Net on a 'resnet26' backbone, randomly initialised (no pretrained encoder weights).
student = models.unet.Unet('resnet26', classes = 12, encoder_weights = None).cuda()
# SaveFeatures comes from helper; presumably it hooks the given module and keeps
# its latest output — confirm there.
sf = SaveFeatures(student.encoder.relu)

# Teacher: U-Net on a 'resnet34' backbone, restored from a saved checkpoint.
teacher = models.unet.Unet('resnet34', classes = 12, encoder_weights = None).cuda()
# Deserialize onto the CPU first: without map_location, torch.load restores each
# tensor to the device it was saved from (possibly cuda:0), allocating memory on a
# GPU other than the one selected with set_device. load_state_dict then copies the
# values into the teacher's CUDA parameters.
teacher.load_state_dict(torch.load('../saved_models/resnet34/pretrained_0.pt', map_location='cpu'))
# The teacher is restored from a checkpoint and is never passed to `unfreeze`; assuming
# it only supplies target features (NOTE(review): confirm in the training loop), freeze
# it. eval() keeps BatchNorm running statistics fixed during forward passes, and
# requires_grad_(False) stops backward from computing/accumulating teacher gradients.
teacher.eval()
teacher.requires_grad_(False)
sf2 = SaveFeatures(teacher.encoder.relu)

In [8]:
# Show every student parameter with its shape and whether it is currently trainable.
for param_name, tensor in student.named_parameters():
    print(param_name, tensor.shape, tensor.requires_grad)

encoder.conv1.weight torch.Size([64, 3, 7, 7]) True
encoder.bn1.weight torch.Size([64]) True
encoder.bn1.bias torch.Size([64]) True
encoder.layer1.0.conv1.weight torch.Size([64, 64, 3, 3]) True
encoder.layer1.0.bn1.weight torch.Size([64]) True
encoder.layer1.0.bn1.bias torch.Size([64]) True
encoder.layer1.0.conv2.weight torch.Size([64, 64, 3, 3]) True
encoder.layer1.0.bn2.weight torch.Size([64]) True
encoder.layer1.0.bn2.bias torch.Size([64]) True
encoder.layer1.1.conv1.weight torch.Size([64, 64, 3, 3]) True
encoder.layer1.1.bn1.weight torch.Size([64]) True
encoder.layer1.1.bn1.bias torch.Size([64]) True
encoder.layer1.1.conv2.weight torch.Size([64, 64, 3, 3]) True
encoder.layer1.1.bn2.weight torch.Size([64]) True
encoder.layer1.1.bn2.bias torch.Size([64]) True
encoder.layer1.2.conv1.weight torch.Size([64, 64, 3, 3]) True
encoder.layer1.2.bn1.weight torch.Size([64]) True
encoder.layer1.2.bn1.bias torch.Size([64]) True
encoder.layer1.2.conv2.weight torch.Size([64, 64, 3, 3]) True
encode

In [21]:
def unfreeze(model, stage) : 
    """Make exactly one stage of a U-Net trainable and freeze every other parameter.

    Stages map to parameter-name prefixes:
      * 0      -> encoder stem (``encoder.conv*`` / ``encoder.bn*``)
      * 1..4   -> ``encoder.layer{stage}.``
      * 5..9   -> ``decoder.blocks.{stage - 5}.``
      * 10     -> names starting with ``segmentation``

    Args:
        model: a ``torch.nn.Module``; its parameters' ``requires_grad`` flags are
            modified in place.
        stage: integer in [0, 10].

    Returns:
        The same ``model`` object, for convenient re-binding/chaining.

    Raises:
        ValueError: if ``stage`` is not an integer in [0, 10]. (An invalid stage
            used to just print a warning and leave every flag untouched, so the
            caller silently kept whatever was trainable before.)
    """
    # bool is an Integral subclass, but True/False are never meaningful stages.
    # Non-integers such as 2.5 used to slip through the range checks and build a
    # bogus prefix ('encoder.layer2.5') that matched the wrong parameters.
    if isinstance(stage, bool) or not isinstance(stage, numbers.Integral) or not 0 <= stage <= 10:
        raise ValueError('Invalid stage input: only integers from 0 to 10 are valid')
    stage = int(stage)

    if stage == 0:
        prefixes = ('encoder.conv', 'encoder.bn')
    elif stage <= 4:
        # Trailing '.' so the prefix matches a whole module path component.
        prefixes = ('encoder.layer{}.'.format(stage),)
    elif stage <= 9:
        prefixes = ('decoder.blocks.{}.'.format(stage - 5),)
    else:
        prefixes = ('segmentation',)

    # One pass: a parameter is trainable iff its name falls under the chosen stage.
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(prefixes)

    return model

# Build a fresh student and re-attach the feature hook to it. Assuming SaveFeatures
# hooks the module it is given, re-binding `student` alone would leave `sf` attached
# to the model built in the earlier cell, so it would not observe this network.
student = models.unet.Unet('resnet26', classes = 12, encoder_weights = None).cuda()
sf = SaveFeatures(student.encoder.relu)
stage = 0
student = unfreeze(student, stage)
# Show which parameters the chosen stage left trainable.
for name, param in student.named_parameters() : 
    print(name, param.shape, param.requires_grad)

encoder.conv1.weight torch.Size([64, 3, 7, 7]) True
encoder.bn1.weight torch.Size([64]) True
encoder.bn1.bias torch.Size([64]) True
encoder.layer1.0.conv1.weight torch.Size([64, 64, 3, 3]) False
encoder.layer1.0.bn1.weight torch.Size([64]) False
encoder.layer1.0.bn1.bias torch.Size([64]) False
encoder.layer1.0.conv2.weight torch.Size([64, 64, 3, 3]) False
encoder.layer1.0.bn2.weight torch.Size([64]) False
encoder.layer1.0.bn2.bias torch.Size([64]) False
encoder.layer1.1.conv1.weight torch.Size([64, 64, 3, 3]) False
encoder.layer1.1.bn1.weight torch.Size([64]) False
encoder.layer1.1.bn1.bias torch.Size([64]) False
encoder.layer1.1.conv2.weight torch.Size([64, 64, 3, 3]) False
encoder.layer1.1.bn2.weight torch.Size([64]) False
encoder.layer1.1.bn2.bias torch.Size([64]) False
encoder.layer1.2.conv1.weight torch.Size([64, 64, 3, 3]) False
encoder.layer1.2.bn1.weight torch.Size([64]) False
encoder.layer1.2.bn1.bias torch.Size([64]) False
encoder.layer1.2.conv2.weight torch.Size([64, 64, 3, 