In [3]:
import torch
import torch.nn as nn
from src.models.frame_detector import Transformer, YNetEncoder, SimpleDenseNet
from src.models.fuvai import YNet
from src.datasets.acouslic_dataset import AcouslicDatasetFull
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import matplotlib.pyplot as plt
from monai.transforms import (
    Compose,
    ScaleIntensity,
    RandGaussianNoise,
    RandGaussianSmooth,
    EnsureType
)
from torch.utils.data import DataLoader, default_collate

this_path = Path().resolve()
# data_path = this_path.parent / 'data/preprocessed'
data_path = this_path.parent / 'data/acouslic-ai-train-set'
assert data_path.exists()

In [4]:
# HP
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 2
hidden_dim = 768

# dataset
from sklearn.model_selection import KFold

metadata_path = data_path / 'circumferences/fetal_abdominal_circumferences_per_sweep.csv'
df = pd.read_csv(metadata_path)
kf = KFold(n_splits=5, shuffle=True, random_state=0)
for fold_n, (train_ids, val_ids) in enumerate(kf.split(df.subject_id.unique())):
    # print(train_ids, val_ids)
    print(len(train_ids), len(val_ids))
    break

preproc_transforms = Compose([
                    ScaleIntensity(),
                    EnsureType()
                ])
frame_transforms = Compose([
            RandGaussianSmooth(prob=0.1),
            RandGaussianNoise(prob=0.5),
            ])
train_dataset = AcouslicDatasetFull(metadata_path=metadata_path,
                                    preprocess_transforms=preproc_transforms,
                                    subject_ids=train_ids)

def my_collate_fn(batch):
    return batch
train_dl = DataLoader(train_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      num_workers=2,
                      collate_fn=my_collate_fn)

224 57


In [5]:

# # sample = next(iter(train_dl))
# for idx, sample in enumerate(train_dl):
#     print(type(sample), sample[0]['image'].shape)
#     if idx == 10:
#         break
# # print(sample['image'].shape, sample['labels'].shape)
# # image = sample['image']
# # image = image.unsqueeze(2).to(DEVICE)
# # print(image.shape)


In [6]:
# pretrained encoder
ckpt_path = this_path.parent / 'data/fuvai_weights.pt'
pretrained_model = YNet(1, 64, 1)
ckpt = torch.load(ckpt_path)
pretrained_model.load_state_dict(ckpt)
# pretrained_model.to(DEVICE)

encoder = YNetEncoder(pretrained_model=pretrained_model)
encoder.eval()
encoder.to(DEVICE)

proj = nn.Linear(encoder.out_channels,
                              hidden_dim)
proj.to(DEVICE)

# transformer
transformer = Transformer(hidden_dim=hidden_dim)
transformer.to(DEVICE)


# input = torch.randn(1, 840, 1, 256, 256).to(DEVICE)



Transformer(
  (model): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (pos_embed): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [7]:
for idx, samples in enumerate(train_dl):
    print(f'Batch index {idx}')
    encodings = []
    labels = []
    for sample in samples:
        image = sample['image'].unsqueeze(1).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output = encoder(image)
            print(output.shape)
        output = proj(output)
        print(output.shape)
        encodings.append(output)
        labels.append(sample['labels'])
    # encodings = pad_sequence(encodings, batch_first=True)
    # print(encodings.shape)
    # encodings, mask = transformer(encodings)
    break

Batch index 0
torch.Size([796, 512])
torch.Size([796, 768])
torch.Size([679, 512])
torch.Size([679, 768])


In [8]:
encodings_trans, mask_trans = transformer(encodings)


In [9]:
frames = encodings_trans[~mask_trans, :]
print(frames.shape)

torch.Size([1475, 768])


In [10]:
labels = torch.cat(labels)

In [11]:
classifier = SimpleDenseNet(input_size=hidden_dim,
                                output_size=3)
classifier.to(DEVICE)

logits = classifier(frames)

In [12]:
logits.shape

torch.Size([1475, 3])

In [6]:
encodings_pad = pad_sequence(encodings, batch_first=True)

def create_mask(src, pad_value=0):
    src_seq_len = src.shape[1]
    src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool)
    # src_padding_mask = (src == padding_value)
    src_padding_mask = torch.any(src == pad_value, dim=(2))

   
    return src_mask, src_padding_mask

src_mask, src_padding_mask = create_mask(encodings_pad)
print(src_mask.shape, src_padding_mask.shape)
# print(src_padding_mask.reshape(encodings.shape[0], encodings.shape[1]).shape)


torch.Size([796, 796]) torch.Size([2, 796])


In [8]:
src_padding_mask[1]

metatensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, Fal

In [22]:
labels

torch.Size([1475])