In [1]:
import torch
import torch.nn as nn
from src.models.frame_detector import (
    Transformer,
    YNetEncoder,
    SimpleDenseNet,
    build_models,
    create_mask)
from src.models.fuvai import YNet
from src.datasets.acouslic_dataset import AcouslicDatasetFull
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import matplotlib.pyplot as plt
from monai.transforms import (
    Compose,
    ScaleIntensity,
    RandGaussianNoise,
    RandGaussianSmooth,
    EnsureType
)
from torch.utils.data import DataLoader, default_collate
from torchmetrics.classification import Accuracy
from torch.optim import Adam
from tqdm import tqdm
from sklearn.model_selection import KFold


# Resolve data locations relative to the notebook's working directory so no
# absolute, machine-specific path is hard-coded.
this_path = Path().resolve()
# data_path = this_path.parent / 'data/preprocessed'
data_path = this_path.parent / 'data/acouslic-ai-train-set'
# Fail early and report which directory was expected if the dataset is missing.
assert data_path.exists(), f'Dataset directory not found: {data_path}'

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Toy check of how per-sequence predictions concatenate: probs.argmax(dim=1)
# has one entry per row (4 here), so concatenating two of them gives a flat
# vector of 8 predictions.
# The pasted "..." continuation prompts were removed: they are only tolerated
# because IPython strips pasted prompts, and raise a TypeError in plain Python.
probs = torch.tensor([[0.16, 0.26, 0.58],
                      [0.22, 0.61, 0.17],
                      [0.71, 0.09, 0.20],
                      [0.05, 0.82, 0.13]])
torch.cat([probs.argmax(dim=1), probs.argmax(dim=1)]).shape

torch.Size([8])

### Debug shape mismatch

In [4]:
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_epoch(encoder,
                projector,
                transformer,
                classifier,
                optimizer,
                criterion,
                dataloader):
    """Train projector, transformer and classifier for one pass over `dataloader`.

    The encoder is used as a frozen feature extractor: it runs under
    ``torch.no_grad()`` and is never stepped by this function.

    Args:
        encoder: Feature extractor applied to every sample's image.
        projector: Trainable module applied to the encoder output.
        transformer: Trainable module called with the list of per-sample
            encodings; must return ``(encodings, masks)`` where ``~masks``
            selects the positions that are classified.
        classifier: Trainable module mapping the selected encodings to
            3-class logits.
        optimizer: Optimizer over the trainable parameters.
        criterion: Loss called as ``criterion(logits, labels)``. The running
            average assumes it returns a per-batch mean (``reduction='mean'``).
        dataloader: Iterable of batches; each batch is a sequence of dicts
            with the keys ``'image'``, ``'labels'`` and ``'uuid'``.

    Returns:
        Tuple ``(loss, accuracy)``: the frame-weighted average of the batch
        losses and the macro-averaged multiclass accuracy over the epoch.

    Raises:
        ValueError: If the dataloader yields no labelled frames.
    """
    # no_grad() only disables gradient tracking; eval() additionally keeps any
    # dropout / normalisation layers of the frozen encoder in inference mode.
    encoder.eval()
    projector.train()
    transformer.train()
    classifier.train()

    acc_metric_avg = Accuracy(task='multiclass',
                              num_classes=3,
                              average='macro').to(DEVICE)
    loss_sum = 0.0
    n_frames = 0
    for samples in dataloader:
        # model fwd: encode each sample separately, they may differ in length
        encodings = []
        labels = []
        for sample in samples:
            image = sample['image'].unsqueeze(1).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                output = encoder(image)
            encodings.append(projector(output))
            labels.append(sample['labels'])

        encodings, masks = transformer(encodings)

        # classifier: keep only the positions the transformer left unmasked
        encodings = encodings[~masks, :]
        labels = torch.cat(labels).to(DEVICE)
        logits = classifier(encodings)
        if logits.shape[0] != labels.shape[0]:
            print('Logits and labels shape mismatch')
            print(f'Logits: {logits.shape}, labels: {labels.shape}')
            # List every sample in the batch; indexing samples[1] directly
            # would raise IndexError on a batch that holds a single sample.
            print(', '.join(f'S{i}: {s["uuid"]}'
                            for i, s in enumerate(samples, start=1)))

        optimizer.zero_grad()
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean, so weight it by the batch's
        # frame count before dividing by the total number of frames below.
        batch_frames = len(labels)
        loss_sum += loss.item() * batch_frames
        n_frames += batch_frames

        # Metric bookkeeping needs no autograd graph.
        preds = nn.functional.softmax(logits.detach(), dim=1)
        acc_metric_avg.update(preds, labels)

    if n_frames == 0:
        raise ValueError('dataloader produced no frames to train on')

    return loss_sum / n_frames, acc_metric_avg.compute()

# --- Hyper-parameters ---
NUM_EPOCHS = 20
batch_size = 2
hidden_dim = 768
lr = 1e-4
workers = 2
# Per-class weights handed to the 3-class loss.
weights = torch.tensor([0.3441, 32.8396, 15.6976]).to(DEVICE)

# --- Dataset ---
metadata_path = data_path / 'circumferences/fetal_abdominal_circumferences_per_sweep.csv'
df = pd.read_csv(metadata_path)
kf = KFold(n_splits=5, shuffle=True, random_state=0)
# Only the first of the five folds is used.
# NOTE(review): KFold.split yields positional indices into the array of unique
# subject ids, not the id values themselves — confirm this is what
# AcouslicDatasetFull expects for `subject_ids`.
fold_n, (train_ids, val_ids) = next(enumerate(kf.split(df.subject_id.unique())))

preproc_transforms = Compose([ScaleIntensity()])
# Frame-level augmentation (RandGaussianSmooth / RandGaussianNoise) is
# currently disabled.
dataset_kwargs = dict(metadata_path=metadata_path,
                      preprocess_transforms=preproc_transforms)
train_dataset = AcouslicDatasetFull(subject_ids=train_ids, **dataset_kwargs)
val_dataset = AcouslicDatasetFull(subject_ids=val_ids, **dataset_kwargs)

def my_collate_fn(batch):
    """Identity collate function: return the batch exactly as received.

    Passed to ``DataLoader(collate_fn=...)`` so each batch stays a plain
    sequence of per-sample objects instead of being stacked.
    """
    return batch

# Both loaders share every setting except the dataset and the shuffling.
loader_kwargs = dict(batch_size=batch_size,
                     num_workers=workers,
                     collate_fn=my_collate_fn)
train_dl = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
print(f'Dataset sizes: train={len(train_dataset)}, val={len(val_dataset)}')

# create models and move to device
encoder, projector, transformer, classifier = build_models(hidden_dim)
for _module in (encoder, projector, transformer, classifier):
    _module.to(DEVICE)
print('Models created and moved to device')

# optimizer: the encoder's parameters are deliberately left out
params = [p
          for trainable in (projector, transformer, classifier)
          for p in trainable.parameters()]
optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999))

criterion = nn.CrossEntropyLoss(weight=weights, reduction='mean')



Dataset sizes: train=224, val=57
Models created and moved to device


In [5]:
# Debug cell: replay the training forward pass batch by batch and stop at the
# first batch where the number of classified frames differs from the number
# of labels.
projector.train()
transformer.train()
classifier.train()

for samples in train_dl:
    # model fwd
    encodings = []
    labels = []
    for sample in samples:
        print(f"Sample: {sample['image'].shape}")
        image = sample['image'].unsqueeze(1).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            output = encoder(image)
        output = projector(output)
        encodings.append(output)
        labels.append(sample['labels'])
    # Total number of frames entering the transformer, summed over the whole
    # batch so that a batch size other than 2 does not raise IndexError.
    print(f'Before transformers: {sum(s["image"].shape[0] for s in samples)}')
    encodings_t, masks = transformer(encodings)
    print(f'After transformer, enc shape: {encodings_t.shape}, masks sum: {(~masks).sum()}')
    # classifier
    encodings_t = encodings_t[~masks, :]
    labels = torch.cat(labels).to(DEVICE)
    logits = classifier(encodings_t)
    if logits.shape[0] != labels.shape[0]:
        print('Logits and labels shape mismatch')
        print(f'Logits: {logits.shape}, labels: {labels.shape}')
        # Same "S1: ..., S2: ..." format as before, for any batch size.
        print(', '.join(f'S{i}: {s["uuid"]}'
                        for i, s in enumerate(samples, start=1)))
        break

Sample: torch.Size([681, 256, 256])
Sample: torch.Size([660, 256, 256])
Before transformers: 1341
After transformer, enc shape: torch.Size([2, 681, 768]), masks sum: 1340
Logits and labels shape mismatch
Logits: torch.Size([1340, 3]), labels: torch.Size([1341])
S1: f1fcabfc-f998-44c7-8420-c7a5ae5aaab7, S2: 046ed03e-4b35-4519-bb5f-cd4b0474a060


In [8]:
# Shapes of the two per-sample encodings collected for the current batch.
encodings[0].shape, encodings[1].shape

(torch.Size([681, 768]), torch.Size([660, 768]))

In [6]:
# Reproduce the padding/masking step by hand: pad the per-sample encodings to
# a common length (batch first) and build the masks with a pad value of 0.
encodings_pad = pad_sequence(encodings, batch_first=True)
print(encodings_pad.shape)
src_mask, src_padding_mask = create_mask(encodings_pad, 0)
print(src_mask.shape, src_padding_mask.shape)

torch.Size([2, 681, 768])
torch.Size([681, 681]) torch.Size([2, 681])


In [7]:
# Number of positions that the padding mask does NOT flag as padding.
(~src_padding_mask).sum()

metatensor(1340, device='cuda:0')

In [25]:
# Compare two ways of telling real frames from padding along the feature axis:
#   exmaskpos - True where a frame has at least one non-zero feature
#   exmaskneg - True where every feature of a frame is exactly zero
exmaskpos = torch.any(encodings_pad != 0.0, dim=2)
print(exmaskpos.shape, exmaskpos.sum())
exmaskneg = torch.all(encodings_pad == 0.0, dim=2) # default
print(exmaskneg.shape, exmaskneg.sum())
# The complement of exmaskneg should agree with exmaskpos.
print((~exmaskneg).sum())

torch.Size([2, 681]) metatensor(1341, device='cuda:0')
torch.Size([2, 681]) metatensor(21, device='cuda:0')
metatensor(1341, device='cuda:0')


In [18]:
# Counts frames that contain no zero-valued feature at all: a frame with even
# a single zero feature is excluded by this "any == 0" test.
(~torch.any(encodings_pad == 0.0, dim=2)).sum()

metatensor(1340, device='cuda:0')

In [13]:
# NOTE(review): this rebinds `encodings` from the list of per-sample tensors
# to the transformer output, so the cell is not idempotent — re-running it
# feeds the previous output back into the transformer.
encodings, masks = transformer(encodings)
print(encodings.shape, masks.shape)

torch.Size([2, 686, 768]) torch.Size([2, 686])


In [14]:
# Number of positions the transformer's mask leaves unmasked.
(~masks).sum()

metatensor(1357, device='cuda:0')

In [15]:
# Keep only the unmasked positions; boolean indexing with the 2-D mask merges
# the first two dimensions into one. Rebinds `encodings` once more.
encodings = encodings[~masks, :]
print(encodings.shape)

torch.Size([1357, 768])


In [3]:
print('Starting training')
for epoch in range(1, NUM_EPOCHS+1):
    # train_epoch returns (loss, accuracy); unpack both so that `train_loss`
    # really is the loss, and report progress once per epoch.
    train_loss, train_acc = train_epoch(encoder,
                                        projector,
                                        transformer,
                                        classifier,
                                        optimizer,
                                        criterion,
                                        train_dl)
    print(f'Epoch {epoch}/{NUM_EPOCHS}: '
          f'train loss={train_loss:.4f}, train acc={float(train_acc):.4f}')

Starting training


ERROR:tornado.general:SEND Error: Host unreachable


In [4]:
# HP
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 2
hidden_dim = 768

# dataset
# KFold comes from the notebook's import cell; the duplicate mid-notebook
# import was removed.
metadata_path = data_path / 'circumferences/fetal_abdominal_circumferences_per_sweep.csv'
df = pd.read_csv(metadata_path)
kf = KFold(n_splits=5, shuffle=True, random_state=0)
# Only the first of the five folds is used.
# NOTE(review): KFold.split yields positional indices into the array of unique
# subject ids, not the id values themselves — confirm this is what
# AcouslicDatasetFull expects for `subject_ids`.
fold_n, (train_ids, val_ids) = next(enumerate(kf.split(df.subject_id.unique())))
print(len(train_ids), len(val_ids))

preproc_transforms = Compose([
                    ScaleIntensity(),
                    EnsureType()
                ])
# Frame-level augmentation; defined here but not passed to the dataset below.
frame_transforms = Compose([
            RandGaussianSmooth(prob=0.1),
            RandGaussianNoise(prob=0.5),
            ])
train_dataset = AcouslicDatasetFull(metadata_path=metadata_path,
                                    preprocess_transforms=preproc_transforms,
                                    subject_ids=train_ids)

def my_collate_fn(batch):
    """Identity collate function: return the batch exactly as received.

    Passed to ``DataLoader(collate_fn=...)`` so each batch stays a plain
    sequence of per-sample objects instead of being stacked.
    """
    return batch
# Unshuffled loader over the training set; my_collate_fn returns each batch
# unchanged.
train_dl = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False,
    collate_fn=my_collate_fn,
)

224 57


In [5]:

# # sample = next(iter(train_dl))
# for idx, sample in enumerate(train_dl):
#     print(type(sample), sample[0]['image'].shape)
#     if idx == 10:
#         break
# # print(sample['image'].shape, sample['labels'].shape)
# # image = sample['image']
# # image = image.unsqueeze(2).to(DEVICE)
# # print(image.shape)


In [6]:
# pretrained encoder
ckpt_path = this_path.parent / 'data/fuvai_weights.pt'
pretrained_model = YNet(1, 64, 1)
# Load the weights onto the CPU first: the model is still on the CPU at this
# point, and without map_location a checkpoint saved from a GPU cannot be
# loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the file holds a plain state dict).
ckpt = torch.load(ckpt_path, map_location='cpu')
pretrained_model.load_state_dict(ckpt)
# pretrained_model.to(DEVICE)

# The encoder is used for inference only, hence eval mode.
encoder = YNetEncoder(pretrained_model=pretrained_model)
encoder.eval()
encoder.to(DEVICE)

# Linear projection from the encoder's channel count to the transformer width.
proj = nn.Linear(encoder.out_channels,
                              hidden_dim)
proj.to(DEVICE)

# transformer (last expression of the cell, so its repr is displayed)
transformer = Transformer(hidden_dim=hidden_dim)
transformer.to(DEVICE)


# input = torch.randn(1, 840, 1, 256, 256).to(DEVICE)



Transformer(
  (model): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (pos_embed): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [7]:
# Encode and project the samples of the first batch only (the loop breaks
# after one iteration), printing the shape after each stage. `encodings` and
# `labels` stay in the notebook namespace for the cells that follow.
for idx, samples in enumerate(train_dl):
    print(f'Batch index {idx}')
    encodings = []
    labels = []
    for sample in samples:
        image = sample['image'].unsqueeze(1).unsqueeze(0).to(DEVICE)
        # The encoder runs without gradient tracking; the projection does not.
        with torch.no_grad():
            output = encoder(image)
            print(output.shape)
        output = proj(output)
        print(output.shape)
        encodings.append(output)
        labels.append(sample['labels'])
    # encodings = pad_sequence(encodings, batch_first=True)
    # print(encodings.shape)
    # encodings, mask = transformer(encodings)
    break

Batch index 0
torch.Size([796, 512])
torch.Size([796, 768])
torch.Size([679, 512])
torch.Size([679, 768])


In [8]:
# Run the transformer on the list of per-sample encodings; it returns the
# transformed encodings together with a mask.
encodings_trans, mask_trans = transformer(encodings)


In [9]:
# Keep only the positions the mask leaves unmasked; boolean indexing with the
# mask merges the leading dimensions into one.
frames = encodings_trans[~mask_trans, :]
print(frames.shape)

torch.Size([1475, 768])


In [10]:
# Concatenate the per-sample label tensors into one flat tensor.
# NOTE(review): rebinds `labels` from a list to a tensor, so this cell cannot
# simply be re-run without re-running the cell that builds the list.
labels = torch.cat(labels)

In [11]:
# Classification head: maps each hidden_dim-sized frame encoding to 3 outputs.
classifier = SimpleDenseNet(input_size=hidden_dim,
                                output_size=3)
classifier.to(DEVICE)

# One row of logits per selected frame.
logits = classifier(frames)

In [12]:
# Expect one row per selected frame and one column per class.
logits.shape

torch.Size([1475, 3])

In [6]:
# Pad the per-sample encodings to a common length, batch dimension first.
encodings_pad = pad_sequence(encodings, batch_first=True)

def create_mask(src, pad_value=0):
    """Build the attention mask and the padding mask for a padded batch.

    Args:
        src: Padded batch indexed as (batch, sequence, feature); the feature
            axis (dim 2) is reduced to decide whether a position is padding.
        pad_value: Value that fills padded positions (default 0).

    Returns:
        Tuple ``(src_mask, src_padding_mask)``:
            src_mask: all-False boolean tensor of shape (seq_len, seq_len),
                created on the same device as ``src``.
            src_padding_mask: boolean tensor of shape (batch, seq_len), True
                where every feature of the position equals ``pad_value``.
    """
    src_seq_len = src.shape[1]
    src_mask = torch.zeros((src_seq_len, src_seq_len),
                           dtype=torch.bool,
                           device=src.device)
    # A position is padding only when ALL of its features equal pad_value.
    # Testing with `any` instead flags a real frame that merely contains one
    # feature equal to pad_value, which drops that frame from the output.
    # (A real frame made up entirely of pad_value is still indistinguishable
    # from padding.)
    src_padding_mask = torch.all(src == pad_value, dim=2)

    return src_mask, src_padding_mask

# Build both masks for the padded batch (default pad value) and check shapes.
src_mask, src_padding_mask = create_mask(encodings_pad)
print(src_mask.shape, src_padding_mask.shape)
# print(src_padding_mask.reshape(encodings.shape[0], encodings.shape[1]).shape)


torch.Size([796, 796]) torch.Size([2, 796])


In [8]:
# Padding mask of the second sequence in the batch.
# NOTE(review): this prints the whole row; a summary such as .sum() would keep
# the notebook output short.
src_padding_mask[1]

metatensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, Fal

In [22]:
# NOTE(review): the recorded output below is a torch.Size, which is what
# `labels.shape` would display — presumably the cell was edited after it ran.
labels

torch.Size([1475])