In [1]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

from datasets import EchoNetDataset
from models import EchoNet

In [2]:
video_dir = "/home/tienyu/data/EchoNet-Dynamic/Videos"
target_csv = "/home/tienyu/data/EchoNet-Dynamic/FileList.csv"

batch_size = 16
num_epochs = 50
log_every = 200
lr = 1e-4

In [3]:
# Transformation for frames -> grayscale images
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
])

# Input management
trainset = EchoNetDataset(
    video_dir=video_dir,
    target_csv=target_csv,
    split="train",
    transform=transform,
)

valset = EchoNetDataset(
    video_dir=video_dir,
    target_csv=target_csv,
    split="val",
    transform=transform,
)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
valloader = DataLoader(valset, batch_size=batch_size)
dataloaders = {"train": trainloader, "val": valloader}

Fetching masks: 100%|██████████| 7419/7419 [00:23<00:00, 312.62it/s]
Fetching masks: 100%|██████████| 1279/1279 [00:04<00:00, 295.26it/s]


In [4]:
# from torchsummary import summary
# summary(model, (1, 112, 112))

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
model = EchoNet(device)
model = model.to('cuda')

In [7]:
# Metric
criterion_bce = nn.BCEWithLogitsLoss()
criterion_mse = nn.MSELoss()
# Optimization Setting
optimizer = optim.SGD(model.parameters(),
                      lr=lr,
                      momentum=0.9,
                      weight_decay=1e-4)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

In [8]:
best_val_loss = torch.finfo(torch.float32).max
N_val = len(valloader.dataset)

iterations = 0
since = time.time()
for epoch in range(num_epochs):
    print("Epoch {}/{}".format(epoch, num_epochs - 1))
    print("-" * 20)

    for phase in ("train", "val"):
        running_loss_seg = 0.0
        running_loss_volume = 0.0
        running_loss_ef = 0.0
        within_epoch_interations = 0

        if phase == "train":
            model.train()
        else:
            model.eval()
            print()

        for inputs, labels, masks in dataloaders[phase]:
            #             video_tensor, video_edv, video_esv = [v.to(device) for v in inputs]
            #             ef, edv, esv = [l.to(device) for l in labels]
            #             mask_edv, mask_esv = [m.to(device) for m in masks]
            video_tensor = inputs[0].to(device)
            input_frames = torch.cat(inputs[1:]).to(device)
            volumes = torch.cat(labels[1:]).unsqueeze(1).float().to(device)
            masks = torch.cat(masks).to(device)
            efs = labels[0].unsqueeze(1).float().to(device)

            if phase == "train":
                iterations += 1
                within_epoch_interations += 1

            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):

                masks_pred, volumes_pred = model(input_frames, goal='mask&volume')
                # Compute losses
                loss_seg = criterion_bce(masks_pred, masks)
                loss_volume = criterion_mse(volumes_pred, volumes)

                efs_pred = model(video_tensor, goal='ef')
                loss_ef = criterion_mse(efs_pred, efs)

                loss_mse = loss_volume + loss_ef

                if phase == "train":
                    loss_seg.backward(retain_graph=True)
                    loss_mse.backward()
                    optimizer.step()

            running_loss_seg += loss_seg.item() * 2 * batch_size
            running_loss_volume += loss_volume.item() * 2 * batch_size
            running_loss_ef += loss_ef.item() * batch_size

            if not iterations % log_every and phase == "train":
                iter_elapsed = time.time() - since
                print((
                    f"[{iter_elapsed//60:>3.0f}m {iter_elapsed%60:2.0f}s] "
                    f"Iteration: {iterations:>4.0f} | "
                    f"{phase.title()} | "
                    f"Segmentation BCE Loss: {running_loss_seg/(within_epoch_interations * 2 * batch_size):.5f} "
                    f"MSE Loss(v): {running_loss_volume/(within_epoch_interations * 2 * batch_size):.5f} "
                    f"MSE Loss(ef): {running_loss_ef/(within_epoch_interations*batch_size):.5f}"
                ))

        if phase == "val":
            val_loss = ((running_loss_seg / (N_val * 2)) + 
                        (running_loss_volume / (N_val * 2)) +
                        (running_loss_ef / N_val))
            print((
                f"Total {phase.title()} loss: {val_loss:.5f} | "
                f"Segmentation BCE Loss: {running_loss_seg/(N_val * 2):.5f} "
                f"MSE Loss(v): {running_loss_volume/(N_val * 2):.5f} "
                f"MSE Loss(ef): {running_loss_ef/(N_val):.5f}"
            ))
            if val_loss < best_val_loss:
                torch.save(model.state_dict(),
                           "checkpoints/best_checkpoint.pt")
                best_val_loss = val_loss

        else:
            scheduler.step()

    print()
time_elapsed = time.time() - since
print(
    f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
print((f"Best validation loss: {best_val_loss:.5f}\n"))

Epoch 0/49
--------------------
[  5m 36s] Iteration:  200 | Train | Segmentation BCE Loss: 0.52705 MSE Loss(v): 2136.04042 MSE Loss(ef): 493.15766
[ 11m  1s] Iteration:  400 | Train | Segmentation BCE Loss: 0.41071 MSE Loss(v): 1753.77086 MSE Loss(ef): 331.33429

Total Val loss: 7290.07325Segmentation BCE Loss: 0.22387 MSE Loss(v): 7139.01777 MSE Loss(ef): 150.83161

Epoch 1/49
--------------------
[ 18m 24s] Iteration:  600 | Train | Segmentation BCE Loss: 0.20482 MSE Loss(v): 1330.87580 MSE Loss(ef): 181.99184


KeyboardInterrupt: 