In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F

import numpy as np
import tqdm

from machine_perception.models.stm.model import STM, load_stm_state_dict
from machine_perception.datasets.stm_dataset import MoseStmDataset

In [2]:
DATA_ROOT = (
    R"D:\Documents\University\MSc\Sem1\MachinePerception\machine-perception\data"
)
WEIGHTS_PATH = R"D:\Documents\University\MSc\Sem1\MachinePerception\machine-perception\resources\stm\STM_weights.pth"

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"{DEVICE = }")

DEVICE = device(type='cuda')


In [None]:
def resize_frames_and_masks(
    frames_t: torch.Tensor, masks_t: torch.Tensor, new_size: tuple[int, int]
) -> tuple[torch.Tensor, torch.Tensor]:
    return (
        F.interpolate(frames_t[0], size=new_size, mode="bilinear").unsqueeze(dim=0),
        F.interpolate(masks_t[0], size=new_size, mode="nearest").unsqueeze(dim=0),
    )


def fine_tune_stm(
    stm: STM,
    dataset: MoseStmDataset,
    weights_save_dir: str | Path,
    epochs: int = 1,
    n_frames_to_sample: int = 3,
    device: str = "cpu",
):
    weights_save_dir = Path(weights_save_dir)
    weights_save_dir.mkdir(exist_ok=True)

    stm.train()

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
    optimizer = torch.optim.Adam(stm.parameters(), lr=1e-05)

    idxs = list(range(0, len(dataset)))
    for epoch in range(epochs):
        loss_accumulated = 0
        tqdm_obj = tqdm.tqdm(idxs, desc=f"Epoch {epoch}")
        for idx in tqdm_obj:
            torch.cuda.empty_cache()
            frames_t, masks_t, num_objects_t, info = dataset[idx]
            frames_t = frames_t.unsqueeze(dim=0)
            masks_t = masks_t.unsqueeze(dim=0)

            num_frames = info["num_frames"]

            frame_idxs = np.random.choice(
                num_frames, min(n_frames_to_sample, num_frames), replace=False
            )
            frame_idxs.sort()

            frames_t, masks_t = frames_t[:, :, frame_idxs], masks_t[:, :, frame_idxs]
            frames_t, masks_t = resize_frames_and_masks(
                frames_t, masks_t, new_size=(384, 384)
            )
            frames_t, masks_t = frames_t.to(device), masks_t.to(device)

            optimizer.zero_grad()
            total_loss = 0
            prev_keys, prev_values = None, None
            for t in range(len(frame_idxs)):
                if prev_keys is None and prev_values is None:
                    prev_keys, prev_values = stm(
                        frames_t[:, :, t], masks_t[:, :, t], num_objects_t
                    )
                    continue

                # key, value = stm(frames_t[:, :, t-1], masks_t[:, :, t-1], num_objects_t)
                logit = stm(frames_t[:, :, t], prev_keys, prev_values, num_objects_t)
                label = torch.argmax(masks_t[:, :, t], dim=1)
                loss = criterion(logit, label)
                total_loss += loss

                key, value = stm(frames_t[:, :, t], masks_t[:, :, t], num_objects_t)
                prev_keys = torch.cat([prev_keys, key], dim=3)
                prev_values = torch.cat([prev_values, value], dim=3)

            loss_accumulated += total_loss.item()
            # tqdm_obj.set_description(
            #     f"Epoch {epoch}. Average loss over {idx + 1}: {loss_accumulated / (idx + 1)}"
            # )
            tqdm_obj.set_postfix_str(
                f"Average loss over {idx + 1}: {loss_accumulated / (idx + 1)}"
            )
            # print(
            #     f"Average loss over {n_frames_to_sample - 1} frames: {total_loss.item() / (n_frames_to_sample - 1)}"
            # )

            total_loss.backward()
            optimizer.step()

        torch.save(stm.state_dict(), weights_save_dir / f"stm_epoch_{epoch}.pth")

In [13]:
dataset_train = MoseStmDataset(
    DATA_ROOT,
    imset="meta_train_split.json",
    single_object=False,
)
indices = list(range(len(dataset_train)))
dataset_train_subset = data.Subset(dataset_train, indices)
# dataloader_test = data.DataLoader(
#     dataset_test, batch_size=1, shuffle=True, num_workers=0, pin_memory=True
# )

In [14]:
stm = STM()
stm.load_state_dict(load_stm_state_dict(WEIGHTS_PATH))
stm.to(DEVICE)



STM(
  (Encoder_M): Encoder_M(
    (conv1_m): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (conv1_o): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (res2): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias

In [None]:
weights_save_dir = R"D:\Documents\University\MSc\Sem1\MachinePerception\machine-perception\resources\stm\trained_weights"
fine_tune_stm(
    stm,
    dataset_train_subset,
    weights_save_dir=weights_save_dir,
    epochs=3,
    n_frames_to_sample=5,
    device=DEVICE,
)

Epoch 0:   0%|          | 0/1205 [00:00<?, ?it/s]

: 