In [1]:
import torch
import torch.nn as nn
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from src.models.fuvai import YNet
import numpy as np

this_path = Path().resolve()
data_path = this_path.parent.parent / 'acouslic-ai-train-set'

In [9]:
new_md = Path('/home/alejandro/repos/ACUSLIC_chal/data/preprocessed/full-slice_256x256_only-sweeps/metadata.csv')
new_md = pd.read_csv(new_md)
new_md.plane_type.unique()

array([0, 1, 2])

In [2]:
from src.models.fuvai import YNet

class YNetEncoder(nn.Module):
    def __init__(self, pretrained_model):
        super(YNetEncoder, self).__init__()
        
        self.down_conv1 = pretrained_model.down_conv1
        self.down_conv2 = pretrained_model.down_conv2
        self.down_conv3 = pretrained_model.down_conv3
        self.down_conv4 = pretrained_model.down_conv4
        self.max_pool_2x2 = pretrained_model.max_pool_2x2
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        
    def forward(self, x):
        # Encoder part
        x = torch.unbind(x, dim=1)
        data = []
        for item in x:
            x1 = self.down_conv1(item)  
            x2 = self.down_conv2(self.max_pool_2x2(x1))  
            x3 = self.down_conv3(self.max_pool_2x2(x2))  
            x4 = self.down_conv4(self.max_pool_2x2(x3))  
            features = self.max_pool_2x2(x4)  
            data.append(features.unsqueeze(0))
        data = torch.cat(data, dim=0)
        print(data.shape)
        data = self.avgpool(data)
        print(data.shape)

        return torch.flatten(data, -3)

In [3]:
ckpt_path = data_path / 'fuvai_weights.pt'
model = YNet(1, 64, 1)
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt)

encoder = YNetEncoder(pretrained_model=model)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
encoder = encoder.to(device)
encoder.eval()
input = torch.randn(10, 1, 1, 256, 256).to(device)

In [4]:
with torch.no_grad():
    output = encoder(input)
    print(output.shape)

torch.Size([1, 10, 512, 16, 16])
torch.Size([1, 10, 512, 1, 1])
torch.Size([1, 10, 512])


In [7]:
encoder.down_conv4[3].weight.shape

torch.Size([512, 512, 3, 3])

In [58]:
output.squeeze().shape

torch.Size([10, 512])

### Dev

In [3]:
encoder = YNetEncoder(1, 64, 1)
encoder.down_conv1[0] # copying a param with shape torch.Size([64, 1, 3, 3]) from checkpoint

Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [21]:
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# encoder.to(device)
avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
avgpool.to(device)
encoder.eval()
input = torch.randn(10, 1, 1, 256, 256).to(device) # 224 1, 25, 1, 256, 256
output = encoder(input)
print(output.shape)
output = avgpool(output)
print(output.shape)
output = torch.flatten(output, 1)
print(output.shape)

torch.Size([1, 5120])


ValueError: Input dimension should be at least 3

In [5]:
output

tensor([[0.0083, 0.0033, 0.0023,  ..., 0.0102, 0.0146, 0.0041],
        [0.0082, 0.0032, 0.0023,  ..., 0.0102, 0.0146, 0.0041],
        [0.0083, 0.0032, 0.0022,  ..., 0.0103, 0.0146, 0.0041],
        ...,
        [0.0082, 0.0032, 0.0022,  ..., 0.0103, 0.0147, 0.0041],
        [0.0083, 0.0033, 0.0023,  ..., 0.0103, 0.0146, 0.0041],
        [0.0083, 0.0033, 0.0023,  ..., 0.0102, 0.0146, 0.0041]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [6]:
x = torch.unbind(input, dim=1)
print(len(x), x[0].shape)
x = encoder.down_conv1(x[0])
print(x.shape)

2 torch.Size([10, 1, 224, 224])
torch.Size([10, 64, 224, 224])


In [3]:
ckpt_path = data_path / 'fuvai_weights.pt'
ckpt = torch.load(ckpt_path)
for k, v in ckpt.items():
    print(k)

down_conv1.0.weight
down_conv1.0.bias
down_conv1.1.weight
down_conv1.1.bias
down_conv1.1.running_mean
down_conv1.1.running_var
down_conv1.1.num_batches_tracked
down_conv1.3.weight
down_conv1.3.bias
down_conv1.4.weight
down_conv1.4.bias
down_conv1.4.running_mean
down_conv1.4.running_var
down_conv1.4.num_batches_tracked
down_conv2.0.weight
down_conv2.0.bias
down_conv2.1.weight
down_conv2.1.bias
down_conv2.1.running_mean
down_conv2.1.running_var
down_conv2.1.num_batches_tracked
down_conv2.3.weight
down_conv2.3.bias
down_conv2.4.weight
down_conv2.4.bias
down_conv2.4.running_mean
down_conv2.4.running_var
down_conv2.4.num_batches_tracked
down_conv3.0.weight
down_conv3.0.bias
down_conv3.1.weight
down_conv3.1.bias
down_conv3.1.running_mean
down_conv3.1.running_var
down_conv3.1.num_batches_tracked
down_conv3.3.weight
down_conv3.3.bias
down_conv3.4.weight
down_conv3.4.bias
down_conv3.4.running_mean
down_conv3.4.running_var
down_conv3.4.num_batches_tracked
down_conv4.0.weight
down_conv4.0.bias
do

In [4]:
model = YNet(1, 64, 1) # 1, 64, 1
model.load_state_dict(ckpt)

<All keys matched successfully>

In [6]:
layer = model.down_conv1