In [1]:
import sys
sys.path.insert(0, "../")
from spot_master.pos_reg.data import FISHSpotsDataset
import torch
import numpy as np
from torch import Tensor
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.ioff()
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
max_spots = 3000

train_dataset = FISHSpotsDataset(
    meta_csv="meta_train.csv", root_dir="../FISH_spots",
    max_spots=max_spots)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)

test_dataset = FISHSpotsDataset(
    meta_csv="meta_test.csv", root_dir="../FISH_spots",
    max_spots=max_spots)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)

In [3]:
class PosRegNet(nn.Module):
    def __init__(
            self, backbone: nn.Module,
            backbone_out_features: int,
            backbone_in_channels: int,
            input_channels: int = 1,
            n_pos: int = 8000, pos_dim: int = 2,
            hidden_dim: int = 10000,
            img_shape: tuple = (512, 512)):
        super(PosRegNet, self).__init__()
        self.n_pos = n_pos
        self.pos_dim = pos_dim
        self.backbone = backbone
        self.input_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=backbone_in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False
        )
        self.img_shape = torch.tensor(
            img_shape, dtype=torch.float32)
        self.fc1 = nn.Linear(
            in_features=backbone_out_features,
            out_features=hidden_dim,
        )
        self.fc2 = nn.Linear(
            in_features=hidden_dim,
            out_features=n_pos * pos_dim
        )

    def to(self, device):
        self.img_shape = self.img_shape.to(device)
        return super().to(device)

    def forward(self, x: Tensor):
        x = self.input_conv(x)
        x = self.backbone(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = x.view(-1, self.n_pos, self.pos_dim)
        x = torch.sigmoid(x)
        x = x * self.img_shape
        return x


class PosRegResNet18(PosRegNet):
    def __init__(
            self, n_pos: int = 8000, pos_dim: int = 2,
            input_channels: int = 1,
            img_shape: tuple = (512, 512)):
        from torchvision.models import resnet18
        backbone = resnet18(pretrained=True)
        backbone_out_features = backbone.fc.in_features
        backbone.fc = nn.Identity()
        super().__init__(
            backbone, backbone_out_features,
            backbone.conv1.in_channels,
            input_channels, n_pos, pos_dim,
            n_pos, img_shape)


In [4]:
def nearest_neighbor_matching(pred: Tensor, target: Tensor, n_spots: Tensor):
    loss = 0.0
    n_batch = pred.shape[0]
    for b_id in range(n_batch):
        n = n_spots[b_id]
        t_ = target[b_id, :n, :]
        dist_matrix = torch.cdist(pred, t_)
        min_row_dists, _ = torch.min(dist_matrix, dim=2)
        min_col_dists, _ = torch.min(dist_matrix, dim=1)
        loss_row = torch.mean(min_row_dists)
        loss_col = torch.mean(min_col_dists)
        loss += 0.5 * (loss_row + loss_col)
    loss *= 1.0 / n_batch
    return loss

class PosRegLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target, n_spots):
        loss = nearest_neighbor_matching(pred, target, n_spots)
        return loss


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PosRegResNet18(img_shape=(512, 512), n_pos=max_spots).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = PosRegLoss()



In [9]:
# TensorBoard
writer = SummaryWriter("runs/pos_reg_training")

In [7]:
def plt_to_rgb(fig):
    fig.canvas.draw()

    # Now we can save it to a numpy array.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    data = data.swapaxes(0, 2)
    return data

def draw_scatter(pos, img):
    pos = pos.detach().cpu().numpy()
    img = img.detach().cpu().numpy()
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(img, cmap="gray")
    plt.scatter(pos[:, 1], pos[:, 0], s=2, c="red")
    plt.xlim(0, img.shape[1])
    plt.ylim(img.shape[0], 0)
    plt.axis("off")
    plt.tight_layout()
    data = plt_to_rgb(fig)
    plt.close(fig)
    return data

In [10]:
# Training loop
num_epochs = 50
best_val_loss = float("inf")
model_save_path = "./best_posreg_model.pth"

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for idx, batch in enumerate(train_loader):
        images = batch["image"].to(device, dtype=torch.float32)
        coords = batch["coordinates"].to(device, dtype=torch.float32)
        n_spots = batch["n_spots"].to(device, dtype=torch.int32)

        optimizer.zero_grad()
        outputs = model(images)
        loss: Tensor = criterion(outputs, coords, n_spots)
        loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            print(
                f"Epoch: {epoch + 1}/{num_epochs}, "
                f"Batch: {idx + 1}/{len(train_loader)}, "
                f"Loss: {loss.item():.4f}"
            )
            writer.add_scalar(
                "Loss/train_batch", loss.item(),
                epoch * len(train_loader) + idx)
            # record images
            writer.add_image(
                "Image/input", images[0], epoch * len(train_loader) + idx)
            writer.add_image(
                "Image/pred", draw_scatter(outputs[0], images[0, 0]),
                epoch * len(train_loader) + idx)
            writer.add_image(
                "Image/target", draw_scatter(coords[0], images[0, 0]),
                epoch * len(train_loader) + idx)

        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)
    writer.add_scalar("Loss/train", epoch_loss, epoch)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device, dtype=torch.float32)
            coords = batch["coordinates"].to(device, dtype=torch.float32)
            n_spots = batch["n_spots"].to(device, dtype=torch.int32)

            outputs = model(images)
            loss = criterion(outputs, coords, n_spots)
            val_loss += loss.item()

    val_loss /= len(test_loader)
    writer.add_scalar("Loss/val", val_loss, epoch)

    print(
        f"Epoch {epoch + 1}/{num_epochs}, "
        f"Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}"
    )

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), model_save_path)
        print(f"Best model saved with Val Loss: {val_loss:.4f}")

writer.close()


Epoch: 1/50, Batch: 1/537, Loss: 21.9894
Epoch: 1/50, Batch: 11/537, Loss: 28.0620
Epoch: 1/50, Batch: 21/537, Loss: 17.5036
Epoch: 1/50, Batch: 31/537, Loss: 29.0947
Epoch: 1/50, Batch: 41/537, Loss: 31.6296
Epoch: 1/50, Batch: 51/537, Loss: 56.1300
Epoch: 1/50, Batch: 61/537, Loss: 17.2997
Epoch: 1/50, Batch: 71/537, Loss: 28.7822
Epoch: 1/50, Batch: 81/537, Loss: 16.0676
Epoch: 1/50, Batch: 91/537, Loss: 17.2594
Epoch: 1/50, Batch: 101/537, Loss: 11.0928
Epoch: 1/50, Batch: 111/537, Loss: 15.8249
Epoch: 1/50, Batch: 121/537, Loss: 15.3391
Epoch: 1/50, Batch: 131/537, Loss: 10.2027
Epoch: 1/50, Batch: 141/537, Loss: 15.8620
Epoch: 1/50, Batch: 151/537, Loss: 9.7377
Epoch: 1/50, Batch: 161/537, Loss: 38.2111
Epoch: 1/50, Batch: 171/537, Loss: 17.2953
Epoch: 1/50, Batch: 181/537, Loss: 71.0254
Epoch: 1/50, Batch: 191/537, Loss: 10.1126
Epoch: 1/50, Batch: 201/537, Loss: 11.6874
Epoch: 1/50, Batch: 211/537, Loss: 27.2994
Epoch: 1/50, Batch: 221/537, Loss: 96.7795
Epoch: 1/50, Batch: 231