In [1]:
import os
from collections import OrderedDict

import numpy as np
import torch
from torchsummary import summary
from torchvision.models import VGG16_Weights

from SegNetModel import SegNet
from UNetModel import UNet

In [2]:
import sys
# Make modules stored under the resources directory importable.
# NOTE(review): this runs after the SegNetModel / UNetModel imports in the first
# cell; if those modules live under this directory, the path has to be appended
# before they are imported — confirm. The path is also absolute, so it only
# resolves on a machine with this exact layout.
sys.path.append('/home/jupyter/work/resources/')

In [3]:
model = UNet()

In [4]:
summary(model, (3,224,224))

### Loading pre-trained weights (VGG16) into the SegNet encoder

In [5]:
MODEL_WEIGHTS_DIR = './models/weights/'

In [6]:
def state_dict_name_modifier(state_dict: OrderedDict, model_state_dict: OrderedDict) -> OrderedDict:
    """Rename pretrained feature-layer tensors to the key names of a target model.

    Both state dicts are walked in order. Each pretrained tensor is paired with
    the next not-yet-used entry of ``model_state_dict`` that has the same shape;
    model entries in between (whose shapes do not match) are skipped.

    Args:
        state_dict: Pretrained weights, e.g. the torchvision VGG16 state dict.
            Processing stops at the first key that starts with ``'classifier'``.
        model_state_dict: State dict of the module that will receive the weights.
            Only its key names and tensor shapes are read.

    Returns:
        A new ``OrderedDict`` holding the pretrained tensors under the matching
        model key names, in the order of ``state_dict``. Pretrained tensors with
        no same-shaped counterpart left in the model are omitted, so the result
        is meant for ``load_state_dict(..., strict=False)``.
    """
    model_entries = [(name, tensor.shape) for name, tensor in model_state_dict.items()]
    total = len(model_entries)

    renamed_state_dict = OrderedDict()
    cursor = 0  # index of the first model entry that has not been consumed yet
    for name, tensor in state_dict.items():
        # Classifier (fully connected) layers have no counterpart in the encoder.
        if name.startswith('classifier'):
            break

        # Bounded search: never index past the end of the model entries.
        match = next(
            (idx for idx in range(cursor, total) if model_entries[idx][1] == tensor.shape),
            None,
        )
        if match is None:
            # No same-shaped slot left; skip this tensor but keep matching the rest.
            continue

        renamed_state_dict[model_entries[match][0]] = tensor
        # Consume the slot so a following tensor of the same shape cannot
        # overwrite the entry that was just assigned.
        cursor = match + 1
    return renamed_state_dict

In [7]:
weights_path = MODEL_WEIGHTS_DIR + 'pretrained_encoder_weights_DEFAULT.pt'

In [11]:
# Skip this part if you already have renamed and saved model weights.
# The download and renaming only run when the cached file is missing.

if not os.path.isfile(weights_path):
    # exist_ok replaces the separate existence check before creating the directory.
    os.makedirs(MODEL_WEIGHTS_DIR, exist_ok=True)
    weights = VGG16_Weights.IMAGENET1K_V1
    # Fetch the pretrained VGG16 checkpoint, then rename its feature-layer keys
    # to the key names used by model.encoder.
    pretrained_state_dict = torch.hub.load_state_dict_from_url(weights.url)
    state_dict = state_dict_name_modifier(pretrained_state_dict, model.encoder.state_dict())
    torch.save(state_dict, weights_path)

In [23]:
state_dict = torch.load(weights_path)

In [12]:
model.encoder.load_state_dict(state_dict, strict = False)

In [86]:
from torchvision import transforms
# Joint transform for an image and its mask stacked along the channel axis
# (3 image channels + 1 mask channel, hence the 4-element mean/std).
transform_train = transforms.Compose([
    transforms.Resize([224, 224]),
    # No ToTensor step: the pipeline is applied to an already-converted tensor.
    transforms.Normalize((0.485, 0.456, 0.406, 0.5), (0.229, 0.224, 0.225, 0.5)),
    # RandomRotation draws a fresh angle from (-180, +180) on every call, so no
    # external random number is needed. The previous random.randint(...) froze a
    # single bound at definition time and relied on an un-imported `random`.
    # NOTE(review): rotating after Normalize fills the corners with 0 in
    # normalised space — confirm that is the intended fill value for the mask.
    transforms.RandomRotation(180)])

In [46]:
from PIL import Image

In [48]:
image = Image.open('./datasets/segNet/images/1.jpg')
mask = Image.open('./datasets/segNet/mask/1.png')

In [60]:
mask

In [130]:
m = torch.as_tensor(np.array(mask))

In [131]:
m.max()

In [64]:
tr = transforms.ToTensor()

In [65]:
im = tr(image)

In [66]:
ms = tr(mask)

In [133]:
# ToTensor is a transform class: instantiate it first, then call it on the image.
transforms.ToTensor()(mask)

In [123]:
# Load a mask and convert it with the ToTensor instance `tr`.
# NOTE(review): the variable is named ms8 and the following cell loads
# images/8.jpg into image8, but this opens mask/2.png — confirm which mask
# index is intended.
ms8 = tr(Image.open('./datasets/segNet/mask/2.png'))

In [127]:
image8 = tr(Image.open('./datasets/segNet/images/8.jpg'))

In [128]:
image8.max()

In [124]:
ms8.max()

In [68]:
tmp = transforms.Resize([224, 224])

In [71]:
tr(tmp(image)).shape

In [67]:
print(im.shape, ms.shape)

In [73]:
im1 = im.unsqueeze(0)

In [74]:
ms1 = ms.unsqueeze(0)
ms1.shape

In [85]:
ms

In [75]:
im1.shape

In [78]:
both = torch.cat((im, ms), dim = 0)

In [87]:
res = transform_train(both)

In [100]:
im2, ms2 = torch.tensor_split(res,[3], dim = 0)

In [92]:
import numpy as np

In [101]:
im2.shape

In [113]:
def conv_to_img(tensor: torch.Tensor, mean=None, std=None) -> np.ndarray:
    """Convert an image or mask tensor into an array pyplot's ``imshow`` can display.

    Args:
        tensor: Tensor to display. Size-1 dimensions are squeezed away; a
            remaining 3-D result is assumed to be channels-first (C, H, W),
            as produced by ``ToTensor``.
        mean: Optional per-channel mean used when the tensor was normalised.
        std: Optional per-channel std used when the tensor was normalised.
            De-normalisation (``img * std + mean``) is applied only when both
            ``mean`` and ``std`` are given.

    Returns:
        A NumPy array clipped to ``[0, 1]``: shape (H, W, C) for multi-channel
        input and (H, W) for single-channel input such as a mask.
    """
    # detach() drops autograd tracking; cpu() makes the data reachable by NumPy.
    img = tensor.detach().cpu().numpy().squeeze()
    if img.ndim == 3:
        # imshow expects channels last, so move the channel axis to the end.
        # A squeezed single-channel mask is already 2-D and is left alone.
        img = img.transpose(1, 2, 0)
    if mean is not None and std is not None:
        # Undo Normalize; the per-channel values broadcast over the last axis.
        img = img * np.asarray(std) + np.asarray(mean)
    # clip() returns a new array, so the tensor's storage is never modified.
    return img.clip(0, 1)


In [105]:
pic = conv_to_img(im2)

In [114]:
picm = conv_to_img(ms2)

In [108]:
import matplotlib.pyplot as plt

In [109]:
plt.imshow(pic)

In [116]:
plt.imshow(picm)

In [121]:
ms2.max()

In [58]:
# Compose is called with a single input, so the image and its mask are converted
# to tensors and stacked along the channel axis; both then receive the same
# resize and rotation.
transform_train(torch.cat((tr(image), tr(mask)), dim=0))

In [54]:
import torchvision
torchvision.__version__

In [56]:
from torchvision.transforms import v2