In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
from torchvision.models.feature_extraction import create_feature_extractor

from scipy.optimize import linear_sum_assignment
import numpy as np
from tqdm.auto import tqdm
import json
from datetime import datetime
import cv2

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import random

In [None]:
# Основные компоненты обучения

class FilesDataset(torch.utils.data.Dataset):
    """Dataset of rendered views stored as files on disk.

    Each item is ``(json_info, masks_dict, img_tensor)``:

    * ``json_info``  - parsed ``<path_data>/<view>.json`` annotation;
    * ``masks_dict`` - ``{key: (1, H, W) float32 tensor in [0, 1]}`` read from the mask
      files listed in ``views[view]`` (first image channel only);
    * ``img_tensor`` - network input, depending on ``t_type``:
        - ``'sd'``  : ``<path_data>/<view>_depth.jpg``, first channel, shape (1, H, W);
        - ``'img'`` : generated depth ``<path_gen_depth>/<view>_gd.jpg`` (1 channel)
          followed by the generated image ``<path_gen_img>/<view>_gi.jpg`` (3 channels,
          in cv2's BGR order), shape (4, H, W).

    Args:
        t_type: ``'sd'`` or ``'img'`` (see above).
        v_list: view ids. The dataset keeps a shuffled copy; the caller's list is not modified.
        views: mapping ``view id -> {mask key: mask file name relative to path_data}``.
        path_data: directory with annotations, masks and the ``'sd'`` depth maps.
        path_gen_img, path_gen_depth: directories with generated images / depth maps
            (required for ``t_type='img'``).
        deterministic: shuffle with a fixed seed (1) instead of numpy's global RNG state.

    Raises:
        ValueError: unknown ``t_type``, or missing generated-data directories for ``'img'``.
    """

    _T_TYPES = ('sd', 'img')

    def __init__(self, t_type, v_list, views, path_data, path_gen_img=None, path_gen_depth=None, deterministic=False):
        if t_type not in self._T_TYPES:
            raise ValueError(f"t_type must be one of {self._T_TYPES}, got {t_type!r}")
        if t_type == 'img' and (path_gen_img is None or path_gen_depth is None):
            raise ValueError("t_type='img' requires path_gen_img and path_gen_depth")
        self.t_type = t_type
        # Shuffle a copy so the caller's list is left untouched. A private RandomState(1)
        # yields the same permutation as np.random.seed(1) + np.random.shuffle, without
        # resetting numpy's global RNG for the rest of the program.
        self.v_list = list(v_list)
        rng = np.random.RandomState(1) if deterministic else np.random
        rng.shuffle(self.v_list)
        self.views = views
        self.path_data = path_data
        self.path_gen_img = path_gen_img
        self.path_gen_depth = path_gen_depth

    def __len__(self):
        return len(self.v_list)

    def __getitem__(self, idx):
        view = self.v_list[idx]

        p_json = self.path_data + '/' + view + '.json'
        with open(p_json, "r", encoding="utf-8") as f:
            json_info = json.load(f)

        masks_dict = {
            k: self._read_gray(self.path_data + '/' + mask_name)
            for k, mask_name in self.views[view].items()
        }

        if self.t_type == 'sd':
            img_tensor = self._read_gray(self.path_data + '/' + view + '_depth.jpg')
        else:  # 'img' (validated in __init__)
            depth = self._read_gray(self.path_gen_depth + '/' + view + '_gd.jpg')   # (1, H, W)
            img = self._read_color(self.path_gen_img + '/' + view + '_gi.jpg')      # (3, H, W)
            img_tensor = torch.cat([depth, img], dim=0)                              # (4, H, W)

        return json_info, masks_dict, img_tensor

    @staticmethod
    def _imread(path):
        """cv2.imread that raises instead of silently returning None for unreadable files."""
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Cannot read image: {path}")
        return image

    @classmethod
    def _read_gray(cls, path):
        """First channel of the image at `path`, scaled to [0, 1], as a (1, H, W) float32 tensor."""
        return torch.tensor(cls._imread(path)[:, :, 0] / 255).unsqueeze(0).to(torch.float32)

    @classmethod
    def _read_color(cls, path):
        """All three channels of the image at `path` (BGR), scaled to [0, 1], as a (3, H, W) float32 tensor."""
        # numpy arrays have no .permute; move the channel axis with ndarray.transpose.
        return torch.tensor(cls._imread(path).transpose(2, 0, 1) / 255).to(torch.float32)

def custom_collate_fn(batch):
    """Collate ``(json_info, masks_dict, tensor)`` samples into one batch.

    The dict-valued fields cannot be stacked, so they come back as plain lists with one
    entry per sample; only the tensors are stacked along a new leading batch axis.

    Returns:
        (list of json dicts, list of mask dicts, stacked tensor)
    """
    # Single pass over `batch`; unpacking enforces exactly three fields per sample.
    unpacked = [(info, masks, tensor) for info, masks, tensor in batch]
    json_infos = [info for info, _, _ in unpacked]
    mask_dicts = [masks for _, masks, _ in unpacked]
    stacked = torch.stack([tensor for _, _, tensor in unpacked])
    return json_infos, mask_dicts, stacked

def _batch_losses(model, loss_fn, json_info, masks_dicts, img_tensor, device):
    """Forward one batch and compute the weighted total loss plus its five components.

    Predicted queries are first matched to ground-truth houses (Hungarian matching on
    anchor points); every component loss uses that matching.

    Returns:
        (total, components): ``total = loss_fn(components)`` and ``components`` is
        ``[point, aux, prob, mask, image]``.
    """
    back_mask, masks, mask_vectors, image_vector = model(img_tensor.to(device))
    ancors_pred = extract_ancor_pred(mask_vectors)
    ancors_json = [extract_ancor_dict(gt_json) for gt_json in json_info]
    matching = hungarian_matching(ancors_pred, ancors_json)

    components = [
        point_loss(mask_vectors, json_info, matching),
        aux_loss(mask_vectors, json_info, matching),
        prob_loss(mask_vectors, json_info, matching),
        masks_loss(back_mask, masks, masks_dicts, matching, json_info),
        image_loss(image_vector, json_info),
    ]
    return loss_fn(components), components


def _epoch_means(sums, n_batches):
    """Divide per-batch sums by the number of batches (an empty loader yields zeros, not a crash)."""
    return [value / max(n_batches, 1) for value in sums]


def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device,
               writer,
               scheduler=None):
    """Run one training epoch.

    Args:
        model, dataloader, loss_fn, optimizer, device: the usual training objects; `loss_fn`
            is called with the list of five component losses.
        writer: accepted for interface compatibility; not used here.
        scheduler: optional LR scheduler, stepped once per batch.

    Returns:
        list of 6 floats averaged over batches: [total, point, aux, prob, mask, image].
    """
    model.train()
    loss_fn.train()
    sums = [0.0] * 6

    for json_info, masks_dicts, img_tensor in tqdm(dataloader, total=len(dataloader), leave=False, desc='Train'):
        loss, components = _batch_losses(model, loss_fn, json_info, masks_dicts, img_tensor, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # .item() gives plain floats: accumulating the tensors would keep each batch's graph alive.
        sums[0] += loss.item()
        for j, component in enumerate(components, start=1):
            sums[j] += component.item()

    return _epoch_means(sums, len(dataloader))


def test_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               device: torch.device,
               writer,
               scheduler = None):
    """Run one validation epoch: no gradients, no parameter or scheduler updates.

    Model and loss module are switched to eval mode so BatchNorm running statistics are
    not updated with validation data and dropout is disabled. `writer` and `scheduler`
    are accepted for interface compatibility but intentionally unused.

    Returns:
        list of 6 floats averaged over batches: [total, point, aux, prob, mask, image].
    """
    model.eval()
    loss_fn.eval()
    sums = [0.0] * 6

    with torch.no_grad():
        for json_info, masks_dicts, img_tensor in tqdm(dataloader, total=len(dataloader), leave=False, desc='Valid'):
            loss, components = _batch_losses(model, loss_fn, json_info, masks_dicts, img_tensor, device)
            sums[0] += loss.item()
            for j, component in enumerate(components, start=1):
                sums[j] += component.item()

    return _epoch_means(sums, len(dataloader))


def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device,
          path: str,
          writer,
          scheduler=None):
    """Full training loop: one training pass and one validation pass per epoch.

    After every epoch the averaged losses are printed, appended to the history, logged to
    `writer` via ``add_scalar`` when it is not None (e.g. a TensorBoard SummaryWriter), and a
    checkpoint ``<path>/<timestamp>.pth`` plus the history so far ``<path>/<timestamp>.json``
    are written.

    Returns:
        dict mapping each metric name to its list of per-epoch float values.
    """
    metrics = ["train_loss_total",
               "train_loss_point",
               "train_loss_aux",
               "train_loss_prob",
               "train_loss_mask",
               "train_loss_image",
               "valid_loss_total",
               "valid_loss_point",
               "valid_loss_aux",
               "valid_loss_prob",
               "valid_loss_mask",
               "valid_loss_image"]

    results = {name: [] for name in metrics}

    model.to(device)
    loss_fn.to(device)

    for epoch in tqdm(range(epochs)):
        train_losses = train_step(model,
                                  train_dataloader,
                                  loss_fn,
                                  optimizer,
                                  device,
                                  writer,
                                  scheduler)
        test_losses = test_step(model,
                                test_dataloader,
                                loss_fn,
                                device,
                                writer,
                                scheduler)

        # float() accepts numbers and 0-d tensors alike, keeping the history JSON-serialisable.
        values = [float(v) for v in list(train_losses) + list(test_losses)]
        line = [f"Epoch: {epoch + 1:03d}"]
        for name, value in zip(metrics, values):
            results[name].append(value)
            line.append(f"{name}: {value:.3f}")
            if writer is not None:
                writer.add_scalar(name, value, epoch + 1)
        print(' | '.join(line))

        # strftime always includes seconds; str(datetime)[:-7] dropped them when microsecond == 0.
        stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        torch.save(model.state_dict(), path + f"/{stamp}.pth")
        with open(path + f'/{stamp}.json', 'w') as json_file:
            json.dump(results, json_file, indent=4)

    return results

In [None]:
# Функции для расчёта ошибок при обучении

def hungarian_matching(pred, targets):
    """
    Optimal one-to-one assignment between predicted and ground-truth anchor points.

    pred: tensor (B, N, D) with D >= 2; only the first two channels (x, y) are used.
    targets: list of B tensors of shape (M_i, 2); an empty tensor means "no targets".

    The cost is the Euclidean distance between points, minimised with the Hungarian
    algorithm. Gradients are never needed here, so inputs are detached, and targets are
    cast to the dtype/device of `pred` so torch.cdist accepts them.

    Returns:
        list of B tuples (pred_indices, target_indices) of integer numpy arrays;
        both arrays are empty when a sample has no targets.
    """
    assignments = []
    for i in range(pred.shape[0]):
        pred_points = pred[i, :, :2].detach()  # (N, 2)
        target_points = targets[i]             # (M, 2)

        if target_points.shape[0] == 0:
            assignments.append((np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)))
            continue

        target_points = target_points.detach().to(pred_points)
        cost = torch.cdist(pred_points, target_points, p=2).cpu().numpy()  # (N, M)
        row_ind, col_ind = linear_sum_assignment(cost)
        assignments.append((row_ind, col_ind))
    # Each tuple: predicted-vector indices first, then target indices.
    return assignments

# Получение якорей из одного json - на выходе формат переменный
def extract_ancor_dict(gt_json):
    """Ground-truth anchor point of every house in one annotation.

    For each house the anchor is the midpoint of the first two points of the first
    facade listed for it (``houses[k][side][1][:2]``). Coordinates are mapped with
    ``(x - 512) / 512``, the same transform point_loss applies to its coordinate
    targets, so the anchors live in the frame the model's predicted anchors are
    trained in and distances in hungarian_matching are meaningful.

    Returns:
        float32 tensor (num_houses, 2); shape (0, 2) when there are no houses.
    """
    ancors = []
    for house in gt_json['houses'].values():
        first_side = list(house.keys())[0]
        line = torch.tensor(house[first_side][1][:2]).to(torch.float)
        ancor = (line.mean(dim=0) - 512) / 512
        ancors.append(ancor.tolist())
    if not ancors:
        return torch.zeros((0, 2))
    return torch.tensor(ancors)

# Получение якорей для всего батча - на выходе формат постоянный
def extract_ancor_pred(vectors):
    """Predicted anchor point of every query, detached and on the CPU.

    Channels -8:-4 of each query vector are read as two (x, y) points and the anchor is
    their midpoint. With BLETR's 17-channel vectors these are the channels point_loss
    trains on the first facade line's two endpoints.

    vectors: (B, N, D) tensor. Returns: (B, N, 2) tensor without gradient.
    """
    batch, queries = vectors.shape[:2]
    with torch.no_grad():
        segment = vectors.cpu()[:, :, -8:-4]             # (x1, y1, x2, y2)
        endpoints = segment.view(batch, queries, 2, 2)   # two (x, y) points
        anchors = endpoints.mean(dim=-2)                 # midpoint
    return anchors

# Ошибка по нужным координатам для всего батча
def _xy_pair(points, first):
    """Flatten points[first] and points[first + 1] into [x0, y0, x1, y1], keeping only x and y."""
    return [points[first][0], points[first][1], points[first + 1][0], points[first + 1][1]]


def point_loss(mask_vectors, gt_list, matching):
    """Masked MSE between predicted and ground-truth facade coordinates for the whole batch.

    mask_vectors: (B, Q, P) query outputs; channels 5: are the 12 coordinate predictions.
    gt_list: B annotation dicts with a 'houses' mapping.
    matching: B tuples (query indices, house indices) from hungarian_matching.

    Target layout, before normalisation with ``(x - 512) / 512``:
        0-3   right facade, points 2 and 3;
        4-7   first listed facade, points 0 and 1;
        8-11  left facade, points 2 and 3.
    Entries without ground truth (unmatched queries, missing facades) stay NaN and are
    excluded from the loss.

    Returns:
        scalar tensor; a graph-connected zero when no target coordinate exists (MSE over
        zero elements would be NaN and poison training), or None (after printing a
        message) if the batch sizes of `mask_vectors` and `matching` disagree.
    """
    b, o, _ = mask_vectors.shape
    if b != len(matching):
        print('Не совпадают размеры')
        return None
    reals = torch.full((b, o, 12), float('nan'))
    for i in range(b):
        houses = gt_list[i]['houses']
        house_keys = list(houses.keys())
        for q, match in zip(*matching[i]):
            house = houses[house_keys[match]]
            first_side = list(house.keys())[0]
            reals[i, q, 4:8] = torch.tensor(_xy_pair(house[first_side][1], 0), dtype=torch.float32)
            if "right" in house:
                reals[i, q, 0:4] = torch.tensor(_xy_pair(house["right"][1], 2), dtype=torch.float32)
            if "left" in house:
                reals[i, q, 8:12] = torch.tensor(_xy_pair(house["left"][1], 2), dtype=torch.float32)
    reals = (reals - 512) / 512
    preds = mask_vectors[:, :, 5:]
    mask = ~torch.isnan(reals)
    if not mask.any():
        return preds.sum() * 0.0
    mask = mask.to(preds.device)
    return torch.nn.MSELoss()(preds[mask], reals.to(preds.device)[mask])

# Ошибка по углу фасада для всего батча
def aux_loss(mask_vectors, gt_list, matching):
    """Masked MSE on the per-facade angle for the whole batch.

    mask_vectors: (B, Q, P) query outputs; channels 3:5 predict (right, left) angles.
    Targets are ``houses[k]["right"][2]`` / ``houses[k]["left"][2]`` divided by 90;
    missing facades and unmatched queries stay NaN and are excluded.

    Returns:
        scalar tensor; a graph-connected zero when no target exists (MSE over zero
        elements would be NaN), or None (after printing a message) if the batch sizes of
        `mask_vectors` and `matching` disagree.
    """
    b, o, _ = mask_vectors.shape
    if b != len(matching):
        print('Не совпадают размеры')
        return None
    reals = torch.full((b, o, 2), float('nan'))
    for i in range(b):
        houses = gt_list[i]['houses']
        house_keys = list(houses.keys())
        for q, match in zip(*matching[i]):
            house = houses[house_keys[match]]
            if "right" in house:
                reals[i, q, 0] = house["right"][2]
            if "left" in house:
                reals[i, q, 1] = house["left"][2]
    reals = reals / 90
    preds = mask_vectors[:, :, 3:5]
    mask = ~torch.isnan(reals)
    if not mask.any():
        return preds.sum() * 0.0
    mask = mask.to(preds.device)
    return torch.nn.MSELoss()(preds[mask], reals.to(preds.device)[mask])

# Ошибка по вероятностям для всего батча
def prob_loss(mask_vectors, gt_list, matching):
    """Binary cross-entropy on the three per-query probability logits for the whole batch.

    mask_vectors: (B, Q, P) query outputs; channels 0:3 are logits for
    (is a matched building, has a right facade, has a left facade).
    Unmatched queries get all-zero targets.

    Returns:
        scalar BCE-with-logits loss, or None (after printing a message) if the batch
        sizes of `mask_vectors` and `matching` disagree.
    """
    batch, queries, _ = mask_vectors.shape
    targets = torch.zeros(batch, queries, 3)
    if len(matching) != batch:
        print('Не совпадают размеры')
        return None
    for sample in range(batch):
        pred_idx, gt_idx = matching[sample]
        # Every matched query is a building.
        targets[sample, :, 0][pred_idx] = 1
        houses = gt_list[sample]['houses']
        house_keys = list(houses.keys())
        for query, match in zip(pred_idx, gt_idx):
            sides = houses[house_keys[match]].keys()
            if "right" in sides:
                targets[sample, query, 1] = 1
            if "left" in sides:
                targets[sample, query, 2] = 1
    logits = mask_vectors[:, :, :3]
    return torch.nn.BCEWithLogitsLoss()(logits, targets.to(logits.device))

def angle_with_horizontal(o, v):
    """Elevation angle, in degrees, of the 3-D segment from point `o` to point `v`.

    Computed as arcsin(dz / |v - o|): the angle between the segment and the plane of
    the first two axes, positive when `v` has the larger third coordinate. Inputs are
    converted to float64, so integer coordinates are accepted (torch.norm rejects
    integer tensors). Returns a Python float; a zero-length segment gives NaN.
    """
    o = torch.as_tensor(o, dtype=torch.float64)
    v = torch.as_tensor(v, dtype=torch.float64)
    d = v - o  # segment vector
    # Clamp guards against |ratio| creeping past 1 through rounding (arcsin would be NaN).
    sin_angle = (d[2] / torch.norm(d)).clamp(-1.0, 1.0)
    angle_deg = torch.arcsin(sin_angle) * 180 / torch.pi
    return angle_deg.item()

def image_loss(preds, gt):
    """MSE between the image-level head output and its two targets.

    preds: (B, 2) predictions. gt: B annotation dicts with 3-D points 'o' and 'v' and a
    value 't'. Targets per sample are [elevation angle of o->v, t], both divided by 90
    so they share one scale. Previously the angle was left in degrees, which made it
    dominate the MSE by orders of magnitude over the /90-scaled 't'.
    """
    targets = torch.tensor(
        [[angle_with_horizontal(s['o'], s['v']) / 90, s['t'] / 90] for s in gt],
        dtype=torch.float32,
    )
    return torch.nn.MSELoss()(preds, targets.to(preds.device))

# Ошибка масок по всему батчу: masks_dicts [{'0_left':_, '1_right':_}, ...]
def masks_loss(back_mask, masks, masks_dicts, matching, gt_list):
    """Per-pixel cross-entropy between matched predicted facade masks and ground truth.

    back_mask: (B, 1, H, W) background logits.
    masks: (B, Q, 2, H, W) per-query logits; channel 0 is the right facade, 1 the left.
    masks_dicts: per-sample dicts keyed '<house>_right' / '<house>_left', presumably
        holding (1, H, W) masks as produced by FilesDataset.

    For every sample a (1 + K)-class problem is built: class 0 is background, classes
    1..K are the K matched facade masks (right before left, in matching order). The
    target class per pixel is the argmax of [1 - sum(gt masks), gt_1, ..., gt_K].
    Samples with no matched facade are skipped.

    Returns:
        mean loss over contributing samples; a graph-connected zero if none contributes;
        None (after printing a message) if the batch sizes of `masks` and `matching` disagree.
    """
    b = masks.shape[0]
    if b != len(matching):
        print('Не совпадают размеры')
        return None
    cel = torch.nn.CrossEntropyLoss()
    loss = 0.0
    count = 0
    for i in range(b):
        houses = gt_list[i]['houses']
        house_keys = list(houses.keys())
        target_mask = []
        predic_mask = []
        for q, match in zip(*matching[i]):
            k = house_keys[match]
            if "right" in houses[k]:
                target_mask.append(masks_dicts[i][k + '_right'])
                predic_mask.append(masks[i, q, 0].unsqueeze(0))
            if "left" in houses[k]:
                target_mask.append(masks_dicts[i][k + '_left'])
                predic_mask.append(masks[i, q, 1].unsqueeze(0))
        if not target_mask:
            # No matched facade in this image: nothing to stack, no foreground class.
            continue
        real_m_1 = torch.stack(target_mask, dim=0)                  # (K, 1, H, W)
        real_m_0 = (1 - real_m_1.sum(dim=0)).unsqueeze(0)           # (1, 1, H, W) background
        real_idx = torch.cat([real_m_0, real_m_1], dim=0).argmax(dim=0).long()  # (1, H, W)
        pred_m = torch.cat([back_mask[i], torch.cat(predic_mask, dim=0)], dim=0).unsqueeze(0)
        loss += cel(pred_m, real_idx.to(pred_m.device))
        count += 1
    if count == 0:
        return back_mask.sum() * 0.0
    return loss / count

class LossWeighting(torch.nn.Module):
    """Combine several scalar losses with learnable weights.

    Keeps one learnable log-variance s_i per loss (initialised to 0) and returns
    sum_i exp(-s_i) * L_i + s_i. The +s_i term keeps the weights from collapsing to zero.
    """

    def __init__(self, num_losses):
        super().__init__()
        self.log_vars = torch.nn.Parameter(torch.zeros(num_losses))

    def forward(self, losses):
        # losses[i] is weighted by log_vars[i]; more losses than log_vars raises IndexError.
        return sum(
            torch.exp(-self.log_vars[idx]) * value + self.log_vars[idx]
            for idx, value in enumerate(losses)
        )

In [None]:
# Модули модели

class MLP(nn.Module):
    """Feed-forward network: Linear + ReLU for every hidden width, then a final Linear.

    Args:
        input_dim: size of the last input dimension.
        hidden_dims: iterable of hidden layer widths (may be empty).
        output_dim: size of the output.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        blocks.append(nn.Linear(widths[-1], output_dim))
        # The attribute name `model` determines the state_dict keys; keep it.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

class MHAttentionMap(nn.Module):
    """Multi-head attention weights (no values) of query vectors over a 2-D feature map.

    Returns only the softmax-normalised attention maps, one spatial map per query and head.
    """
    def __init__(self, query_dim, hidden_dim, num_heads, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        # 1 / sqrt(per-head dimension): the usual scaled dot-product factor.
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k):
        """
        Args:
            q: 3-D query tensor (B, Q, query_dim).
            k: 4-D feature map (B, query_dim, H, W).
        Returns:
            (B, Q, num_heads, H, W) attention weights; the softmax runs jointly over
            heads and spatial positions (everything after the query axis).
        """
        q = self.q_linear(q)
        # Apply k_linear as a 1x1 convolution so the feature map keeps its spatial layout.
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        # Split the hidden dimension into heads.
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        # Per-head dot product of every query with every spatial position.
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
        weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size())
        return weights

def _expand(tensor, length: int):
    """Repeat every batch item `length` times consecutively: (B, C, H, W) -> (B * length, C, H, W).

    Output order is b0, b0, ..., b1, b1, ..., i.e. the same order as flattening a
    (B, length, ...) tensor over its first two axes.
    """
    return torch.repeat_interleave(tensor, int(length), dim=0)

class MaskHeadConv(nn.Module):
    """Convolutional head that turns per-query attention maps into mask logits.

    For every query the input feature map is concatenated with that query's attention
    maps. The result passes through conv/GroupNorm/ReLU stages and is fused with three
    skip feature maps (`fpns`) at their resolutions. Each query yields 2 output channels,
    finally upsampled by 4.

    Args:
        dim: channels of the first conv = feature-map channels + attention-map channels.
        fpn_dims: channel counts of the three skip feature maps, in the order of `fpns`.
        context_dim: base width; the stage widths are context_dim // 2, // 4, // 8, // 16.

    NOTE: nn.GroupNorm(8, C) requires every C used below (including `dim`) to be divisible by 8.
    """
    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()
        # inter_dims[5] (context_dim // 64) is never used.
        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        self.dim = dim
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(8, dim)
        self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
        self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
        self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
        self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
        # Two output channels per query.
        self.out_lay = torch.nn.Conv2d(inter_dims[4], 2, 3, padding=1)

        # 1x1 adapters project each skip feature map to the width of the stage it is added to.
        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

    def forward(self, x, bbox_mask, fpns):
        """
        Args:
            x: 4-D feature map (B, C, H, W); tiled once per query.
            bbox_mask: 5-D attention maps (B, S, heads, H, W); S is the number of queries.
            fpns: three skip feature maps with batch size B, channels matching fpn_dims.
        Returns:
            (B, S, 2, H_out, W_out) logits, where H_out x W_out is 4x the spatial size of fpns[2].
        """
        # Only b (batch) and s (queries) are used; the third entry is the head count.
        b, s, h, _, _ = bbox_mask.shape
        # Tile x once per query and stack it channel-wise with that query's attention maps.
        x = torch.cat([_expand(x, s), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        # Skip connection 1: repeat the skip map per query, upsample x to its size, add.
        cur_fpn = self.adapter1(fpns[0])
        cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        # Skip connection 2.
        cur_fpn = self.adapter2(fpns[1])
        cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        # Skip connection 3.
        cur_fpn = self.adapter3(fpns[2])
        cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)
        x = self.out_lay(x)

        # (B*S, 2, h, w) -> (B, S*2, h, w) so interpolate treats every mask as a channel, then x4 upsample.
        x = x.view(b, s*2, *x.shape[-2:])
        x = F.interpolate(x, scale_factor=4, mode='nearest')
        x = x.view(b, s, 2, *x.shape[-2:])
        return x

class Transformer(nn.Module):
    """Encoder-decoder built from PyTorch's standard transformer stacks, over a feature map.

    The (B, C, H, W) map is flattened into H*W tokens and encoded. The learned query
    embeddings, shared across the batch, are then decoded against the encoded memory.
    """

    def __init__(self, hidden_dim=128, nheads=8, num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        # The prototype layers stay registered as attributes; they are part of the
        # state_dict, so the attribute names must not change.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nheads, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_encoder_layers)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nheads, batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_decoder_layers)

    def forward(self, src, query_embed):
        """
        Args:
            src: (B, C, H, W) feature map with C == hidden_dim.
            query_embed: (Q, C) query embeddings.
        Returns:
            (decoded queries (B, Q, C), encoder memory reshaped back to (B, C, H, W)).
        """
        batch, channels, height, width = src.shape
        tokens = src.flatten(2).transpose(1, 2)                       # (B, H*W, C)
        memory = self.encoder(tokens)
        queries = query_embed.unsqueeze(0).expand(batch, -1, -1)      # (B, Q, C)
        decoded = self.decoder(queries, memory)
        memory_map = memory.view(batch, height, width, channels).permute(0, 3, 1, 2)
        return decoded, memory_map

class BLETR(nn.Module):
    """DETR-style model: ResNet-50 backbone + transformer + mask head.

    forward(x) returns ``(back, masks, vectors, image)``:
        back    - (B, 1, H', W') background mask logits;
        masks   - (B, num_q, 2, H', W') per-query mask logits;
        vectors - (B, num_q, 17): 1 building logit, 2 facade logits, 2 values from
                  `fo_aux`, 12 coordinates (concatenated in that order);
        image   - (B, 2) image-level regression output from the `dof` head.

    Args:
        num_q: number of object queries.
        in_channels: input channels. When it differs from the pretrained backbone's 3,
            the first conv is replaced by a freshly initialised one with this many inputs.
        hidden_dim: transformer width.
        nheads, num_encoder_layers, num_decoder_layers: transformer settings.

    NOTE: the learned positional embeddings have 16 positions per axis, so the layer4
    feature map may be at most 16x16 (presumably 512x512 inputs or smaller).
    """
    def __init__(self, num_q=10, in_channels=3, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        self.n = nheads
        self.q = num_q
        return_nodes = ["layer1", "layer2", "layer3", "layer4"]
        backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        if backbone.conv1.in_channels != in_channels:
            # Use the requested input width (it was hard-coded to 4, ignoring in_channels).
            backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.feat = create_feature_extractor(backbone, return_nodes=return_nodes)
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        self.transformer = Transformer(hidden_dim, nheads,
                                       num_encoder_layers, num_decoder_layers)
        self.att_map = MHAttentionMap(hidden_dim, 256, nheads)
        self.mask_head = MaskHeadConv(hidden_dim+nheads, [1024, 512, 256], hidden_dim)
        self.query_pos = nn.Parameter(torch.rand(self.q, hidden_dim))
        self.row_embed = nn.Parameter(torch.rand(hidden_dim // 2, 16))
        self.col_embed = nn.Parameter(torch.rand(hidden_dim // 2, 16))
        self.final = torch.nn.Conv2d(self.q*2, self.q*2, 3, padding=1)
        self.background = nn.Conv2d(self.q*2, 1, 1)
        self.bild_p = nn.Linear(hidden_dim, 1)
        self.face_p = nn.Linear(hidden_dim, 2)
        self.fo_aux = MLP(hidden_dim, [hidden_dim]*3, 2)
        self.coords = MLP(hidden_dim, [hidden_dim]*3, 12)
        self.dof = nn.Sequential(
            nn.Conv2d(hidden_dim*2, hidden_dim, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim//2, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim//2, hidden_dim//4, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            # Input width follows the last conv (was hard-coded to 64 = 256 // 4).
            nn.Linear(hidden_dim // 4, 2))

    def forward(self, x):
        features = self.feat(x)
        src = self.conv(features["layer4"])
        b, c, h, w = src.shape
        # Positional encoding: the first half of the channels varies along the width, the second
        # half along the height. Slicing to (w, h) allows maps smaller than 16x16; 16x16 is unchanged.
        pos = torch.cat([
            self.row_embed[:, :w].unsqueeze(1).repeat(1, h, 1),
            self.col_embed[:, :h].unsqueeze(2).repeat(1, 1, w)], dim=0).unsqueeze(0).repeat(b, 1, 1, 1)
        hs, memory = self.transformer(src+pos, self.query_pos)
        attn = self.att_map(hs, memory)
        masks = self.mask_head(src, attn, [features["layer3"], features["layer2"], features["layer1"]])
        masks = self.final(masks.flatten(1, 2))
        back = self.background(masks)
        masks = masks.view(b, self.q, 2, *masks.shape[-2:])
        building = self.bild_p(hs)
        faces = self.face_p(hs)
        camera = self.fo_aux(hs)
        points = self.coords(hs)
        vectors = torch.cat([building, faces, camera, points], dim=-1)
        image = torch.cat([src, memory], dim=1)
        image = self.dof(image)
        return back, masks, vectors, image

In [None]:
# NOTE(review): the `...` paths are placeholders and must be filled in before running.
# v_list: presumably a JSON list of view ids; FilesDataset reads <view>.json and
# <view>_depth.jpg for each id — confirm.
with open(..., "r", encoding="utf-8") as f:
    v_list = json.load(f)

# views: presumably a JSON object mapping view id -> {mask key: mask file name};
# FilesDataset iterates views[view].items() — confirm.
with open(..., "r", encoding="utf-8") as f:
    views = json.load(f)

In [None]:
DATA_DIR = r'E:\city_data'  # folder with <view>.json, mask files and <view>_depth.jpg
SPLIT_SEED = 42

# Seed the subsample too, otherwise random_state=42 alone does not make the split reproducible.
# NOTE(review): 17476 is presumably the number of usable views in v_list — confirm.
split_rng = random.Random(SPLIT_SEED)
subset = split_rng.sample(v_list[:17476], 1000)
train_views, valid_views = train_test_split(subset, test_size=0.3, random_state=SPLIT_SEED)

train_dataset = FilesDataset(t_type='sd',
                             v_list=train_views,
                             views=views,
                             path_data=DATA_DIR)

valid_dataset = FilesDataset(t_type='sd',
                             v_list=valid_views,
                             views=views,
                             path_data=DATA_DIR)

# FilesDataset shuffles only once at construction; reshuffle training batches every epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=8, collate_fn=custom_collate_fn)

In [None]:
# The 'sd' datasets yield single-channel (1, H, W) depth tensors, so the backbone must take
# one input channel. Falls back to CPU when no GPU is present (train() moves it to `device` anyway).
model = BLETR(in_channels=1).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
sum(p.numel() for p in model.parameters())

In [None]:
# One learnable weight per loss term: point, aux, prob, mask, image (five in total).
loss_fn = LossWeighting(5)
sum(p.numel() for p in loss_fn.parameters())

In [None]:
n_epoches = 100
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# The loss weights are learned jointly with the model.
trainable = [*model.parameters(), *loss_fn.parameters()]
optimizer = torch.optim.AdamW(trainable, lr=0.001)
# NOTE(review): OneCycleLR drives the learning rate from max_lr, so lr=0.001 above presumably has no effect — confirm.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-5,
    steps_per_epoch=len(train_loader),
    epochs=n_epoches,
    pct_start=0.1,
)
path = ...  # placeholder: output directory for checkpoints and metric JSON files

In [None]:
# Train for n_epoches; after every epoch a checkpoint (.pth) and the metric history (.json) are written to `path`.
results = train(model, train_loader, valid_loader, optimizer, loss_fn, n_epoches, device, path, writer=None, scheduler=scheduler)