In [1]:
import sys
sys.path.insert(0, "../")
import torch
from spot_master.unet.data import (
    FISHSpotsDataset, RandomHorizontalFlip,
    RandomRotation, ToTensorWrapper,
)
from spot_master.unet.model import UNet
from spot_master.unet.utils import DiceLoss, RMSELoss
from spot_master.unet.train import train
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Augmentation pipeline for the training split only: a random horizontal flip,
# a random rotation, then conversion to tensors via the project's wrapper.
transform = Compose(
    [RandomHorizontalFlip(), RandomRotation(), ToTensorWrapper()]
)

In [3]:
# Data-loading settings shared by both splits (kept in one place instead of
# repeating the literals in each DataLoader call).
DATA_ROOT = "../FISH_spots"
BATCH_SIZE = 8
NUM_WORKERS = 4

# Training split: augmented with `transform` and reshuffled every epoch.
train_dataset = FISHSpotsDataset(
    meta_csv="meta_train.csv", root_dir=DATA_ROOT,
    transform=transform)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Held-out split passed to train() below; shuffle=False keeps a fixed order.
# NOTE(review): no transform is given here (so not even ToTensorWrapper) —
# confirm FISHSpotsDataset yields samples in the same format as the training
# split when transform is None.
test_dataset = FISHSpotsDataset(
    meta_csv="meta_test.csv", root_dir=DATA_ROOT)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set FINE_TUNE = False to train from freshly initialised weights instead of
# starting from the saved checkpoint (replaces commenting the load line out).
FINE_TUNE = True
PRETRAINED_CHECKPOINT = "./best_unet_model.pth"
LEARNING_RATE = 1e-4

# NOTE(review): positional args presumably (in_channels, out_channels, depth) —
# confirm against the UNet signature.
model = UNet(1, 1, 4).to(device)
if FINE_TUNE:
    # map_location remaps tensors onto `device`, so a checkpoint saved on a GPU
    # still loads when `device` fell back to "cpu" above. weights_only=True
    # restricts unpickling to tensors / plain containers, which is all a
    # state_dict needs.
    model.load_state_dict(
        torch.load(PRETRAINED_CHECKPOINT, map_location=device, weights_only=True)
    )
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
rmse_loss = RMSELoss()
dice_loss = DiceLoss()

def criterion(pred, target, dice_weight=0.6, rmse_weight=0.4):
    """Weighted sum of the Dice and RMSE losses for one prediction/target pair.

    Uses the module-level ``dice_loss`` and ``rmse_loss`` instances.

    Args:
        pred: model output, passed unchanged to both losses.
        target: ground truth, passed unchanged to both losses.
        dice_weight: weight applied to ``dice_loss(pred, target)`` (default 0.6).
        rmse_weight: weight applied to ``rmse_loss(pred, target)`` (default 0.4).

    Returns:
        ``dice_weight * dice_loss(pred, target) + rmse_weight * rmse_loss(pred, target)``.
    """
    loss_dice = dice_loss(pred, target)
    loss_rmse = rmse_loss(pred, target)
    return dice_weight * loss_dice + rmse_weight * loss_rmse

In [5]:
# TensorBoard writer handed to train() below; inspect with
# `tensorboard --logdir runs`.
writer = SummaryWriter(log_dir="runs/unet_training")

In [6]:
# Run the training loop for 10 epochs; the checkpoint path below is where
# train() is asked to save the best model from this (fine-tuning) run.
train(
    model,
    optimizer,
    criterion,
    writer,
    device,
    train_loader,
    test_loader,
    "best_unet_model_after_fine_tuning.pth",
    num_epochs=10,
)

Epoch: 1/10, Batch: 1/269, Loss: 0.1395
Epoch: 1/10, Batch: 11/269, Loss: 0.1484
Epoch: 1/10, Batch: 21/269, Loss: 0.1272
Epoch: 1/10, Batch: 31/269, Loss: 0.1438
Epoch: 1/10, Batch: 41/269, Loss: 0.1348
Epoch: 1/10, Batch: 51/269, Loss: 0.1287
Epoch: 1/10, Batch: 61/269, Loss: 0.1383
Epoch: 1/10, Batch: 71/269, Loss: 0.1362
Epoch: 1/10, Batch: 81/269, Loss: 0.1298
Epoch: 1/10, Batch: 91/269, Loss: 0.0981
Epoch: 1/10, Batch: 101/269, Loss: 0.1164
Epoch: 1/10, Batch: 111/269, Loss: 0.1220
Epoch: 1/10, Batch: 121/269, Loss: 0.1541
Epoch: 1/10, Batch: 131/269, Loss: 0.1328
Epoch: 1/10, Batch: 141/269, Loss: 0.1376
Epoch: 1/10, Batch: 151/269, Loss: 0.1161
Epoch: 1/10, Batch: 161/269, Loss: 0.1476
Epoch: 1/10, Batch: 171/269, Loss: 0.1155
Epoch: 1/10, Batch: 181/269, Loss: 0.1079
Epoch: 1/10, Batch: 191/269, Loss: 0.1289
Epoch: 1/10, Batch: 201/269, Loss: 0.1255
Epoch: 1/10, Batch: 211/269, Loss: 0.1311
Epoch: 1/10, Batch: 221/269, Loss: 0.1232
Epoch: 1/10, Batch: 231/269, Loss: 0.1267
Epo