In [None]:
#!:bash

In [None]:
# Install the Python packages listed in requirements.txt.
%pip install -r requirements.txt

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload edited modules (e.g. the local `train` module) before running code.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import segmentation_models_pytorch as smp
from nuimages.nuimages import NuImages          


from train import train

# import gc

# gc.collect()
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is {device}")

NUM_EPOCHS = 2
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an int, so fall back to a single worker.
NUM_WORKERS = os.cpu_count() or 1
BATCH_SIZE = 16
NUM_EXAMPLES = 1
FREEZE = False

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from nuimages.nuimages import NuImages
from nuimages.utils.utils import name_to_index_mapping

class NuImagesDataset(Dataset):
    """nuImages key-frame camera images paired with filtered semantic masks.

    Only key frames whose sensor channel is in ``sensor_channels`` (default:
    CAM_FRONT, CAM_FRONT_LEFT, CAM_FRONT_RIGHT) are indexed. In each returned
    mask, pixels whose class id is listed in ``desired_class_ids`` keep that
    id; every other pixel is set to ``background_id``. When
    ``desired_class_ids`` is None or empty, the whole mask becomes
    ``background_id``. Out-of-range ids in ``desired_class_ids`` are ignored.
    """

    def __init__(self,
                 dataroot,
                 version,
                 desired_class_ids=None,    # e.g. [5,7,10]
                 background_id=0,           # value written for all other ids
                 transform=None,
                 target_transform=None,
                 sensor_channels=None):
        self.nuim = NuImages(dataroot=dataroot,
                             version=version,
                             lazy=True,
                             verbose=False)

        default_channels = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT']
        # A set gives O(1) membership checks in the filter below.
        self.sensor_channels = set(sensor_channels or default_channels)

        # `and` short-circuits, so the sensor join only runs for key frames.
        self.sd_tokens = [
            record['token']
            for record in self.nuim.sample_data
            if record['is_key_frame']
            and self._channel_of(record) in self.sensor_channels
        ]

        self.transform = transform
        self.target_transform = target_transform

        # One-time lookup table: index = raw mask id, value = output id.
        max_idx = max(name_to_index_mapping(self.nuim.category).values())
        self.lut = np.full((max_idx + 1,), fill_value=background_id, dtype=np.uint8)
        for cls_id in set(desired_class_ids or []):
            if 0 <= cls_id <= max_idx:
                self.lut[cls_id] = cls_id

    def _channel_of(self, record):
        """Return the sensor channel of a sample_data record (via calibrated_sensor -> sensor)."""
        calib = self.nuim.get('calibrated_sensor', record['calibrated_sensor_token'])
        return self.nuim.get('sensor', calib['sensor_token'])['channel']

    def __len__(self):
        return len(self.sd_tokens)

    def __getitem__(self, idx):
        token = self.sd_tokens[idx]
        record = self.nuim.get('sample_data', token)
        image = Image.open(os.path.join(self.nuim.dataroot,
                                        record['filename'])).convert('RGB')

        sem_mask, _ = self.nuim.get_segmentation(token)
        # Vectorised remap: each raw id is replaced by lut[id].
        mask = Image.fromarray(self.lut[sem_mask.astype(np.uint8)])

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask

resize_size = (256, 256)
original_size = (1024, 2048)

# Image transforms: bilinear resize, tensor in [0, 1], ImageNet normalization.
input_transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Mask transforms: nearest-neighbour resize (label ids are never blended),
# then an unscaled (H, W) int64 tensor.
target_transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    transforms.Lambda(lambda t: t.squeeze(0).long()),
])

# root = os.getenv('NUIMAGES')
root = '/var/tmp/MultiTask_vs_Yolo_Unet/full_nuImages'
train_version = 'v1.0-train'
val_version = 'v1.0-val'

# nuImages semantic-mask class ids (0 is not listed; presumably background):
#   1 animal
#   2 human.pedestrian.adult
#   3 human.pedestrian.child
#   4 human.pedestrian.construction_worker
#   5 human.pedestrian.personal_mobility
#   6 human.pedestrian.police_officer
#   7 human.pedestrian.stroller
#   8 human.pedestrian.wheelchair
#   9 movable_object.barrier
#  10 movable_object.debris
#  11 movable_object.pushable_pullable
#  12 movable_object.trafficcone
#  13 static_object.bicycle_rack
#  14 vehicle.bicycle
#  15 vehicle.bus.bendy
#  16 vehicle.bus.rigid
#  17 vehicle.car
#  18 vehicle.construction
#  19 vehicle.emergency.ambulance
#  20 vehicle.emergency.police
#  21 vehicle.motorcycle
#  22 vehicle.trailer
#  23 vehicle.truck
#  24 flat.driveable_surface
#  31 vehicle.ego
desired_ids = [9, 10, 11, 12, 13, 24]

# Transforms are deliberately omitted so samples stay PIL images for the
# visualisation loop below.
train_dataset = NuImagesDataset(root, train_version, desired_class_ids=desired_ids)
# Use the same class selection as training; without desired_class_ids every
# validation mask pixel would be mapped to background.
val_dataset = NuImagesDataset(root, val_version, desired_class_ids=desired_ids)

indices = [3, 4, 5]

for idx in indices:
    image, mask = train_dataset[idx]
    # image: PIL.Image RGB
    # mask:  PIL.Image (mode 'L') holding the kept class ids
    img_arr = np.array(image)
    mask_arr = np.array(mask)

    plt.figure(figsize=(15, 15))
    plt.imshow(img_arr)
    # Overlay in semi-transparent red only where mask > 0; NaN pixels are
    # drawn with the colormap's "bad" colour (transparent by default).
    plt.imshow(np.where(mask_arr > 0, 1, np.nan),
               cmap='Reds', alpha=0.4, vmin=0, vmax=1)
    plt.axis('off')
    plt.title(f"Sample {idx}")
plt.show()

In [None]:
from nuimages.utils.utils import name_to_index_mapping# Build the mapping
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import segmentation_models_pytorch as smp
from nuimages.nuimages import NuImages          

dataroot, version = '/var/tmp/MultiTask_vs_Yolo_Unet/full_nuImages', 'v1.0-train'
val_version = 'v1.0-val'

# lazy=True defers reading each table until it is first accessed.
nuim = NuImages(
    dataroot=dataroot,
    version=version,
    lazy=True,
    verbose=False,
)
mapping = name_to_index_mapping(nuim.category)

# Print the category-name -> mask-index mapping, one entry per line.
for category_name, category_idx in mapping.items():
    print(f"{category_idx:2d} → {category_name}")


In [None]:
# Number of categories in the v1.0-train split. `nuim` (the train-split handle
# opened in the previous cell) is used because `nuim_train` is only created
# further down the notebook, so referencing it here raises NameError.
num_classes = len(nuim.category)

In [None]:

indices = [3, 4, 5]

for idx in indices:
    # image: PIL.Image RGB; mask: PIL.Image (mode 'L') of class ids
    image, mask = train_dataset[idx]
    img_arr, mask_arr = np.array(image), np.array(mask)
    # 1 where a class is present, NaN elsewhere so only foreground is tinted.
    overlay = np.where(mask_arr > 0, 1, np.nan)

    plt.figure(figsize=(6, 6))
    plt.imshow(img_arr)
    plt.imshow(overlay, cmap='Reds', alpha=0.4, vmin=0, vmax=1)
    plt.axis('off')
    plt.title(f"Sample {idx}")
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import segmentation_models_pytorch as smp
from nuimages.nuimages import NuImages          


from train import train
from typing import Optional
# import gc

# gc.collect()
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is {device}")

NUM_EPOCHS = 2
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an int, so fall back to a single worker.
NUM_WORKERS = os.cpu_count() or 1
BATCH_SIZE = 16
NUM_EXAMPLES = 1
FREEZE = False

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from nuimages.nuimages import NuImages
from nuimages.utils.utils import name_to_index_mapping

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from nuimages.nuimages import NuImages
from nuimages.utils.utils import name_to_index_mapping

# Label written for mask pixels whose class is not in desired_class_ids.
IGNORE_INDEX = 255


class NuImagesDataset(Dataset):
    """nuImages key-frame camera images with compactly relabelled semantic masks.

    Mask ids are remapped through a lookup table:

    * 0 stays 0;
    * the i-th distinct id in ``desired_class_ids`` becomes ``i`` (1-based);
    * every other id becomes ``IGNORE_INDEX``.

    ``self.num_classes`` is ``1 + number of distinct desired ids``.

    Raises:
        ValueError: if a desired id is 0 (reserved) or larger than the
            largest id produced by ``name_to_index_mapping``.
    """

    def __init__(self,
                 dataroot,
                 version,
                 desired_class_ids:
                   Optional[list]=None,
                 transform=None,
                 target_transform=None,
                 sensor_channels=None):

        self.nuim = NuImages(dataroot=dataroot,
                             version=version,
                             lazy=True, verbose=False)
        cameras = set(sensor_channels or
                      ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT'])
        self.sd_tokens = [
            sd['token']
            for sd in self.nuim.sample_data
            if sd['is_key_frame']
            and self.nuim.shortcut('sample_data', 'sensor', sd['token'])['channel']
                in cameras
        ]
        self.transform = transform
        self.target_transform = target_transform
        name2idx = name_to_index_mapping(self.nuim.category)
        self.lut, self.num_classes = self._build_lut(max(name2idx.values()),
                                                     desired_class_ids)

    @staticmethod
    def _build_lut(max_old_id, desired_class_ids):
        """Return ``(lut, num_classes)`` mapping original mask ids to training labels.

        Args:
            max_old_id: largest id that can appear in a raw mask.
            desired_class_ids: original ids to keep, in label order (None = none).

        Raises:
            ValueError: if any desired id is outside ``[1, max_old_id]``.
        """
        # Drop duplicates but keep first-seen order: a repeated id would
        # otherwise inflate num_classes and leave an unused label.
        desired = list(dict.fromkeys(desired_class_ids or []))
        bad = [old_id for old_id in desired if not 1 <= old_id <= max_old_id]
        if bad:
            raise ValueError(
                f"desired_class_ids must be in [1, {max_old_id}] "
                f"(0 is reserved for background), got {bad}")

        lut = np.full((max_old_id + 1,), fill_value=IGNORE_INDEX, dtype=np.uint8)
        lut[0] = 0
        for new_id, old_id in enumerate(desired, start=1):
            lut[old_id] = new_id
        return lut, 1 + len(desired)

    def __len__(self):
        return len(self.sd_tokens)

    def __getitem__(self, idx):
        sd_token = self.sd_tokens[idx]
        sd = self.nuim.get('sample_data', sd_token)
        img_path = os.path.join(self.nuim.dataroot, sd['filename'])
        image = Image.open(img_path).convert('RGB')
        sem_mask, _ = self.nuim.get_segmentation(sd_token)
        # Vectorised remap: every raw id is replaced by lut[id].
        remapped = self.lut[sem_mask.astype(np.uint8)]
        mask = Image.fromarray(remapped)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)

        return image, mask


resize_size = (256, 256)
original_size = (1024, 2048)

# Image transforms: bilinear resize, tensor in [0, 1], ImageNet normalization.
input_transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Mask transforms: nearest-neighbour resize (label ids are never blended),
# then an unscaled (H, W) int64 tensor.
target_transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    transforms.Lambda(lambda t: t.squeeze(0).long()),
])

# root = os.getenv('NUIMAGES')
root = '/var/tmp/MultiTask_vs_Yolo_Unet/full_nuImages'
train_version = 'v1.0-train'
val_version = 'v1.0-val'

# Original nuImages mask ids to learn; the dataset relabels them 1..len(desired_ids).
desired_ids = [9, 10, 11, 12, 13, 24]
# Datasets
train_dataset = NuImagesDataset(root, train_version, desired_class_ids=desired_ids,
                                transform=input_transform,
                                target_transform=target_transform)

val_dataset = NuImagesDataset(root, val_version, desired_class_ids=desired_ids,
                              transform=input_transform,
                              target_transform=target_transform)

# persistent_workers is only valid when worker processes exist (num_workers > 0).
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKERS,
                          drop_last=True,
                          pin_memory=True,
                          persistent_workers=bool(NUM_WORKERS))

# Validation keeps the final partial batch so every validation image is scored.
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        num_workers=NUM_WORKERS,
                        drop_last=False,
                        pin_memory=True,
                        persistent_workers=bool(NUM_WORKERS))

# Not used by the training code below; kept for interactive inspection.
nuim_train = NuImages(dataroot=root, version=train_version, lazy=True, verbose=False)

# One output channel per label the dataset emits: background (0) plus one per
# desired id. len(desired_ids) alone is one short, so the highest label would
# index past the last logit in the loss.
num_classes = train_dataset.num_classes
# NOTE(review): masks also contain IGNORE_INDEX (255) for every non-selected
# class; confirm that train() excludes it from the loss (e.g. ignore_index=IGNORE_INDEX).

model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=num_classes,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)


from train import train

# One scheduler step per optimizer step over the whole run.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    total_steps=len(train_loader) * NUM_EPOCHS,
    pct_start=0.3,
    anneal_strategy='cos'
)

train(model=model,
      optimizer=optimizer,
      train_loader=train_loader,
      val_loader=val_loader,
      num_epochs=NUM_EPOCHS,
      num_classes=num_classes,
      scheduler=scheduler,
      freeze_encoder=False,
      plot_every=1)

Device is cuda
39.93558478355408
40.05818200111389
40.119895458221436
40.1733193397522
40.6914758682251
41.00242018699646
41.005342960357666
41.00759935379028
45.8759081363678
45.88336133956909
45.97970390319824
46.002256870269775
46.57042932510376
46.640608072280884
46.64455056190491
46.64609408378601
51.25409913063049
51.75851774215698
52.165422439575195
52.16622614860535
52.16653823852539
52.61426830291748
52.61571526527405
52.618473291397095
56.49860191345215
57.42618656158447
57.445735692977905
58.53869390487671
58.53903317451477
58.5398154258728
58.540103912353516
58.541067361831665
62.511783838272095
62.51215100288391
62.64790487289429
64.24093866348267
64.24178266525269
64.24202299118042
64.24286890029907
64.2453989982605
68.04193067550659
68.0483787059784
68.04879093170166
69.56587600708008
69.56619095802307
70.03074359893799
70.0311176776886
70.0316903591156
74.16999340057373
74.1706440448761
74.17082524299622
74.58752846717834
74.58821034431458
75.05122113227844
75.051628351

In [None]:
nuim_train = NuImages(
    dataroot=root,
    version=train_version,
    lazy=True,
    verbose=False,
)
# Evaluated last so the notebook displays every category name in table order.
[category["name"] for category in nuim_train.category]

In [None]:
# Resolve the sensor record of the first sample_data entry. shortcut() needs a
# concrete target table; 'sensor' is the one used by the dataset classes.
st = nuim.sample_data[0]
nuim.shortcut('sample_data', 'sensor', st['token'])

In [4]:
import os
import cv2
import re
from nuimages.nuimages import NuImages

# Open the v1.0-train tables with verbose loading output.
dataroot, version = '/var/tmp/MultiTask_vs_Yolo_Unet/full_nuImages', 'v1.0-train'
nuim = NuImages(dataroot=dataroot, version=version, verbose=True)

# Pick sample #50 and locate its key camera image on disk.
sample = nuim.sample[50]
key_camera_token = sample['key_camera_token']
sample_data = nuim.get('sample_data', key_camera_token)
image_path = os.path.join(nuim.dataroot, sample_data['filename'])

# cv2.imread signals a missing/unreadable file by returning None, not by raising.
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Image not found: {image_path}")

# Define which categories to keep
#  - exactly 'vehicle.car'
#  - anything that starts with 'human.pedestrian'
#  - exactly 'animal'
def keep_category(name: str) -> bool:
    """Return True for 'vehicle.car', 'animal', or any name starting with 'human.pedestrian'."""
    if name in ('vehicle.car', 'animal'):
        return True
    return name.startswith('human.pedestrian')

# Collect (annotation, category name) pairs that belong to the chosen frame
# and pass keep_category.
filtered_anns = []
for ann in nuim.object_ann:
    if ann['sample_data_token'] != key_camera_token:
        continue
    cat = nuim.get('category', ann['category_token'])['name']
    if keep_category(cat):
        filtered_anns.append((ann, cat))

# Draw a green box and a label above it for each kept annotation.
for ann, category in filtered_anns:
    xmin, ymin, xmax, ymax = map(int, ann['bbox'])
    cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
    cv2.putText(
        img,
        category,
        (xmin, max(0, ymin - 10)),  # clamp so labels near the top stay in frame
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (0, 255, 0),
        1,
        cv2.LINE_AA
    )

# Save result
# NOTE(review): the directory/file are named 'first_sample' although the
# sample drawn is nuim.sample[50].
output_dir = os.path.join(dataroot, version, 'first_sample')
os.makedirs(output_dir, exist_ok=True)
out_path = os.path.join(output_dir, 'first_sample_filtered.png')
# cv2.imwrite reports failure by returning False rather than raising, so
# check it before claiming success.
if not cv2.imwrite(out_path, img):
    raise OSError(f"Failed to write annotated image to {out_path}")
print(f"Filtered annotation image saved to {out_path}")


Loading nuImages tables for version v1.0-train...
Done loading in 0.000 seconds (lazy=True).
Loaded 67279 sample(s) in 0.073s,
Loaded 872181 sample_data(s) in 2.742s,
Loaded 557715 object_ann(s) in 3.312s,
Loaded 25 category(s) in 0.000s,
Filtered annotation image saved to /var/tmp/MultiTask_vs_Yolo_Unet/full_nuImages/v1.0-train/first_sample/first_sample_filtered.png


In [None]:
import sys
# Make the local HybridNets checkout importable by searching it first.
# Replace with the real absolute path to the HybridNets folder.
HYBRIDNETS_DIR = '/home/devdem/MultiTask_vs_Yolo_Unet/HybridNets'
sys.path.insert(0, HYBRIDNETS_DIR)


In [5]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Optional, Any, Tuple
from torch import nn
from torch.utils.data import DataLoader, Dataset
from IPython.display import clear_output
from tqdm.notebook import tqdm
import torch.optim as optim
torch.backends.cudnn.benchmark = True
import torchvision.transforms as T
from PIL import Image
import gc

from nuimages.nuimages import NuImages
from hybridnets.model.hybridnets import HybridNetsBackbone
from hybridnets.utils.criterion import MultiTaskLoss
from hybridnets.utils.utils import load_pretrained

# ----------------------
# Utility: plot losses
# ----------------------
def plot_losses(train_losses: List[float], val_losses: List[float]):
    """Replace the cell output with per-epoch loss and exp(loss) curves.

    The right panel is titled "Perplexity" and plots ``np.exp`` of each loss.
    NOTE(review): exp(loss) is only a true perplexity for a per-token
    cross-entropy; for a multi-task loss it is just a monotone view of the loss.
    """
    clear_output(wait=True)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    train_epochs = range(1, len(train_losses) + 1)
    val_epochs = range(1, len(val_losses) + 1)
    panels = (
        ('Loss', train_losses, val_losses),
        ('Perplexity', np.exp(train_losses), np.exp(val_losses)),
    )

    for ax, (title, train_curve, val_curve) in zip(axes, panels):
        ax.plot(train_epochs, train_curve, label='train')
        ax.plot(val_epochs, val_curve, label='val')
        ax.set_title(title)
        ax.legend()
        ax.set_xlabel('Epoch')

    plt.show()

# ------------------------------------
# Dataset for NuImages + HybridNets
# ------------------------------------
class NuImagesHybridDataset(Dataset):
    """CAM_FRONT key frames of a nuImages split for segmentation + detection + lane training.

    Each item is ``(img, seg, (bboxes, labels), lane)``:

    * ``img``    -- the RGB image after ``transform`` (a PIL image if ``transform`` is None);
    * ``seg``    -- ``(H, W)`` int64 semantic mask, nearest-neighbour resized to ``target_size``;
    * ``bboxes`` -- ``(N, 4)`` float32 boxes ``(cx, cy, w, h)`` normalized by the original image size;
    * ``labels`` -- ``(N,)`` int64 indices into ``self.nuim.category``;
    * ``lane``   -- ``(H, W)`` int64 all-zero placeholder (no lane labels are loaded).

    ``target_size`` is ``(height, width)``. ``transform`` is applied to the image
    only and is expected to resize it to ``target_size`` so image and masks align.
    """

    def __init__(self,
                 dataroot: str,
                 version: str,
                 transform: Optional[Any] = None,
                 target_size: Tuple[int, int] = (320, 640)):
        self.nuim = NuImages(dataroot=dataroot, version=version, lazy=True, verbose=False)
        self.sd_tokens = [
            sd['token'] for sd in self.nuim.sample_data
            if sd['is_key_frame'] and
               self.nuim.shortcut('sample_data', 'sensor', sd['token'])['channel'] == 'CAM_FRONT'
        ]
        self.transform = transform
        self.target_size = target_size
        self.category_to_id = {cat['name']: i for i, cat in enumerate(self.nuim.category)}

        # Group object annotations by frame once. Filtering the whole object_ann
        # table inside __getitem__ costs O(#annotations) for every sample.
        wanted = set(self.sd_tokens)
        self._anns_by_sd = {}
        for ann in self.nuim.object_ann:
            sd_token = ann['sample_data_token']
            if sd_token in wanted:
                self._anns_by_sd.setdefault(sd_token, []).append(ann)

    def __len__(self) -> int:
        return len(self.sd_tokens)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        tok = self.sd_tokens[idx]
        sd = self.nuim.get('sample_data', tok)
        img_file = os.path.join(self.nuim.dataroot, sd['filename'])
        img = Image.open(img_file).convert('RGB')
        orig_w, orig_h = img.size
        new_h, new_w = self.target_size

        # Label maps must be resized with nearest-neighbour and converted without
        # scaling. Sending them through the image transform (bilinear Resize +
        # ToTensor) blends class ids and divides by 255, so a later .long()
        # truncates nearly every label to 0.
        sem_mask, _ = self.nuim.get_segmentation(tok)
        seg = self._resize_label_map(sem_mask.astype(np.uint8), new_h, new_w)
        # No lane annotations are read; an all-zero map keeps the output format.
        lane = torch.zeros((new_h, new_w), dtype=torch.long)

        # Object annotations -> normalized (cx, cy, w, h) boxes + category indices.
        bboxes, labels = [], []
        for ann in self._anns_by_sd.get(tok, []):
            xmin, ymin, xmax, ymax = ann['bbox']
            bboxes.append([
                (xmin + xmax) / 2 / orig_w,
                (ymin + ymax) / 2 / orig_h,
                (xmax - xmin) / orig_w,
                (ymax - ymin) / orig_h,
            ])
            cat = self.nuim.get('category', ann['category_token'])['name']
            labels.append(self.category_to_id[cat])

        # reshape keeps the (N, 4) layout when a frame has no boxes (N == 0).
        bboxes = torch.tensor(bboxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.long)

        if self.transform:
            img = self.transform(img)

        return img, seg, (bboxes, labels), lane

    @staticmethod
    def _resize_label_map(mask: np.ndarray, height: int, width: int) -> torch.Tensor:
        """Nearest-neighbour resize a uint8 label map to (height, width) as an int64 tensor."""
        resized = Image.fromarray(mask).resize((width, height), resample=Image.NEAREST)
        return torch.from_numpy(np.array(resized, dtype=np.int64))

# ------------------------------------------------
# Training / Validation Loops (HybridNets version)
# ------------------------------------------------
def training_epoch(model: nn.Module,
                   optimizer: torch.optim.Optimizer,
                   criterion: MultiTaskLoss,
                   loader: DataLoader,
                   desc: str) -> float:
    """Run one optimization pass over ``loader`` and return the mean batch loss.

    Images, segmentation masks and lane masks are moved to the model's device;
    the detection targets (boxes, labels) are passed to ``criterion`` as-is.
    """
    target_device = next(model.parameters()).device
    model.train()
    running = 0.0
    for imgs, segs, det_targets, lanes in tqdm(loader, desc=desc):
        boxes, classes = det_targets
        imgs = imgs.to(target_device)
        segs = segs.to(target_device)
        lanes = lanes.to(target_device)

        optimizer.zero_grad()
        seg_pred, det_pred, lane_pred = model(imgs)
        loss, _ = criterion(seg_pred, segs,
                            det_pred, boxes, classes,
                            lane_pred, lanes)
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / len(loader)

@torch.no_grad()
def validation_epoch(model: nn.Module,
                     criterion: MultiTaskLoss,
                     loader: DataLoader,
                     desc: str) -> float:
    """Evaluate ``model`` on ``loader`` without gradients and return the mean batch loss."""
    target_device = next(model.parameters()).device
    model.eval()
    accumulated = 0.0
    for imgs, segs, det_targets, lanes in tqdm(loader, desc=desc):
        boxes, classes = det_targets
        imgs = imgs.to(target_device)
        segs = segs.to(target_device)
        lanes = lanes.to(target_device)

        seg_pred, det_pred, lane_pred = model(imgs)
        loss, _ = criterion(seg_pred, segs,
                            det_pred, boxes, classes,
                            lane_pred, lanes)
        accumulated += loss.item()
    return accumulated / len(loader)

# ----------------------
# Main training script
# ----------------------
def collate_hybrid_batch(batch):
    """Batch ``(img, seg, (bboxes, labels), lane)`` items whose box counts differ.

    Images, segmentation masks and lane masks are stacked along a new batch
    dimension. Boxes and labels are returned as per-image lists, because the
    default collate function stacks tensors and fails when the number of boxes
    differs between images.
    """
    imgs, segs, targets, lanes = zip(*batch)
    bboxes, labels = zip(*targets)
    return (torch.stack(imgs), torch.stack(segs),
            (list(bboxes), list(labels)), torch.stack(lanes))


def main():
    """Fine-tune HybridNets on CAM_FRONT key frames of one nuImages split.

    A fixed-seed random subset of ``VERSION`` is held out for validation, so
    the validation loss (and the best-checkpoint choice) is not measured on
    the training images. The best-val weights go to ``hybridnets_best.pth``,
    the final weights to ``hybridnets_last.pth``, and the loss curves are plotted.
    """
    # Hyperparams
    DATAROOT = '/var/tmp/nuImages'
    VERSION = 'v1.0-mini'
    BATCH_SIZE = 4
    NUM_WORKERS = 2
    NUM_EPOCHS = 5
    LR = 1e-4
    VAL_FRACTION = 0.2
    SPLIT_SEED = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Dataset & Loader
    transform = T.Compose([
        T.Resize((320, 640)),
        T.ToTensor()
    ])
    full_ds = NuImagesHybridDataset(DATAROOT, VERSION, transform)
    if len(full_ds) < 2:
        raise ValueError(f"Need at least 2 samples for a train/val split, got {len(full_ds)}")
    n_val = max(1, int(len(full_ds) * VAL_FRACTION))
    train_ds, val_ds = torch.utils.data.random_split(
        full_ds, [len(full_ds) - n_val, n_val],
        generator=torch.Generator().manual_seed(SPLIT_SEED))

    # Per-image boxes/labels have different lengths, so they need a custom collate.
    # NOTE(review): criterion therefore receives lists of per-image tensors;
    # confirm MultiTaskLoss accepts that format.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, collate_fn=collate_hybrid_batch)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, collate_fn=collate_hybrid_batch)

    # Model & Loss
    num_classes = len(full_ds.category_to_id)
    num_lanes = 1
    model = HybridNetsBackbone(num_classes=num_classes, num_lanes=num_lanes).to(device)
    load_pretrained(model)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    criterion = MultiTaskLoss().to(device)

    # Training
    best_val = float('inf')
    train_losses, val_losses = [], []
    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = training_epoch(model, optimizer, criterion, train_loader, f"Train {epoch}/{NUM_EPOCHS}")
        val_loss = validation_epoch(model, criterion, val_loader, f"Val   {epoch}/{NUM_EPOCHS}")
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), 'hybridnets_best.pth')

    torch.save(model.state_dict(), 'hybridnets_last.pth')
    plot_losses(train_losses, val_losses)

if __name__ == '__main__':
    main()


ModuleNotFoundError: No module named 'hybridnets'