In [1]:
! rm -r machine-perception machine_perception
! git clone https://github.com/Ansever/machine-perception.git
! cp -r machine-perception/src/machine_perception .

Cloning into 'machine-perception'...
remote: Enumerating objects: 126, done.[K
remote: Counting objects: 100% (126/126), done.[K
remote: Compressing objects: 100% (77/77), done.[K
remote: Total 126 (delta 50), reused 105 (delta 32), pack-reused 0 (from 0)[K
Receiving objects: 100% (126/126), 21.30 MiB | 35.35 MiB/s, done.
Resolving deltas: 100% (50/50), done.


In [2]:
! mkdir weights
! wget -O weights/STM_weights.pth "https://www.dropbox.com/s/mtfxdr93xc3q55i/STM_weights.pth?dl=1"

mkdir: cannot create directory ‘weights’: File exists
--2025-06-04 23:31:41--  https://www.dropbox.com/s/mtfxdr93xc3q55i/STM_weights.pth?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.3.18, 2620:100:6018:18::a27d:312
Connecting to www.dropbox.com (www.dropbox.com)|162.125.3.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.dropbox.com/scl/fi/2ep00pf3q305sh4g1op4y/STM_weights.pth?rlkey=kjnnk9dl82btknrn1jmac9cd2&dl=1 [following]
--2025-06-04 23:31:42--  https://www.dropbox.com/scl/fi/2ep00pf3q305sh4g1op4y/STM_weights.pth?rlkey=kjnnk9dl82btknrn1jmac9cd2&dl=1
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc010108e4997f518278921c087d.dl.dropboxusercontent.com/cd/0/inline/CrCi0nlULnMMHrdUO6PigMq38GmSUQGrSv6KYfkYusHfi4H4qCQGSYxuILOHd5BL2oSTOtzrhM5CJAij0a8GGfo3PUW5yCjHNEu8LQ0zP_m4SNdCYKk2EO4vPKHWL-PaiYWGc1obZwyDp7F5a0chL7LH/file?dl=1# [following]
--2025-06-04 23:31

In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F

import numpy as np
import tqdm

from machine_perception.models.stm.model import STM, load_stm_state_dict
from machine_perception.datasets.stm_dataset import MoseStmDataset

In [7]:
DATA_ROOT = (
    "/kaggle/input/mose-subset/data_subset"
)
# WEIGHTS_PATH = "/kaggle/working/weights/STM_weights.pth"
WEIGHTS_PATH = "/kaggle/working/weights/finetune/stm_epoch_4.pth"

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print(f"{DEVICE = }")

DEVICE = device(type='cuda')


In [8]:
def fine_tune_stm(
    stm: STM,
    dataset: MoseStmDataset,
    weights_save_dir: str | Path,
    epochs: int = 1,
    device: str = "cpu",
):
    weights_save_dir = Path(weights_save_dir)
    weights_save_dir.mkdir(exist_ok=True)

    stm.train()

    criterion = nn.CrossEntropyLoss()
    # criterion.to("cpu")
    optimizer = torch.optim.Adam(stm.parameters(), lr=1e-05)

    idxs = list(range(0, len(dataset)))
    for epoch in range(epochs):
        loss_accumulated = 0
        tqdm_obj = tqdm.tqdm(idxs, desc=f"Epoch {epoch}")
        for idx in tqdm_obj:
            frames_t, masks_t, num_objects_t, info = dataset[idx]
            if num_objects_t.item() >= 11:
                continue
            frames_t = frames_t.unsqueeze(dim=0)
            masks_t = masks_t.unsqueeze(dim=0)

            # tqdm_obj.set_postfix_str(
            #     f"Dimensions: {frames_t.shape}"
            # )
            
            num_frames = info["num_frames"]

            frames_t, masks_t = frames_t.to(device), masks_t.to(device)

            optimizer.zero_grad()
            total_loss = None
            keys_l, values_l = [], []
            for t in range(frames_t.shape[2]):
                if len(keys_l) == 0 and len(values_l) == 0:
                    key, value = stm(
                        frames_t[:, :, t], masks_t[:, :, t], num_objects_t
                    )
                    keys_l.append(key)
                    values_l.append(value)
                    continue

                prev_keys = torch.cat(keys_l, dim=3)
                prev_values = torch.cat(values_l, dim=3)
                # key, value = stm(frames_t[:, :, t-1], masks_t[:, :, t-1], num_objects_t)
                logit = stm(frames_t[:, :, t], prev_keys, prev_values, num_objects_t)
                label = torch.argmax(masks_t[:, :, t], dim=1)
                loss = criterion(logit, label)
                if total_loss is None:
                    total_loss = loss
                else:
                    total_loss += loss

                key, value = stm(frames_t[:, :, t], masks_t[:, :, t], num_objects_t)
                keys_l.append(key)
                values_l.append(value)


            loss_accumulated += total_loss.item()
            # tqdm_obj.set_description(
            #     f"Epoch {epoch}. Average loss over {idx + 1}: {loss_accumulated / (idx + 1)}"
            # )
            tqdm_obj.set_postfix_str(
                f"Average loss over {idx + 1}: {loss_accumulated / (idx + 1)}"
            )
            # print(
            #     f"Average loss over {n_frames_to_sample - 1} frames: {total_loss.item() / (n_frames_to_sample - 1)}"
            # )

            total_loss.backward()
            optimizer.step()

        torch.save(stm.state_dict(), weights_save_dir / f"stm_epoch_{epoch}.pth")

In [9]:
dataset_train = MoseStmDataset(
    DATA_ROOT,
    imset="meta_train_split.json",
    single_object=False,
    new_size=(384, 384),
    n_frames_to_sample=3
)
indices = list(range(len(dataset_train)))
dataset_train_subset = data.Subset(dataset_train, indices)

In [10]:
stm = STM()
stm.load_state_dict(load_stm_state_dict(WEIGHTS_PATH))
stm = stm.to(DEVICE)
stm = nn.DataParallel(stm, device_ids=[0, 1])



In [11]:
weights_save_dir = "/kaggle/working/weights/finetune2"
fine_tune_stm(
    stm,
    dataset_train_subset,
    weights_save_dir=weights_save_dir,
    epochs=80,
    device=DEVICE,
)

Epoch 0: 100%|██████████| 302/302 [03:57<00:00,  1.27it/s, Average loss over 302: 0.15483287927006148]
Epoch 1: 100%|██████████| 302/302 [03:57<00:00,  1.27it/s, Average loss over 302: 0.16263369665955676]
Epoch 2: 100%|██████████| 302/302 [03:56<00:00,  1.27it/s, Average loss over 302: 0.16017433661485278]
Epoch 3: 100%|██████████| 302/302 [03:56<00:00,  1.28it/s, Average loss over 302: 0.13961390699957657]
Epoch 4: 100%|██████████| 302/302 [03:56<00:00,  1.28it/s, Average loss over 302: 0.1387542663438183] 
Epoch 5: 100%|██████████| 302/302 [03:56<00:00,  1.28it/s, Average loss over 302: 0.14307312986220616]
Epoch 6: 100%|██████████| 302/302 [03:56<00:00,  1.28it/s, Average loss over 302: 0.1406040512762011] 
Epoch 7: 100%|██████████| 302/302 [03:56<00:00,  1.28it/s, Average loss over 302: 0.13289390187566932]
Epoch 8: 100%|██████████| 302/302 [03:57<00:00,  1.27it/s, Average loss over 302: 0.13059640312158743]
Epoch 9: 100%|██████████| 302/302 [03:55<00:00,  1.28it/s, Average loss o

In [None]:
from IPython.display import FileLink

FileLink("weights/finetune2/stm_epoch_79.pth")