## Import libraries

In [1]:
import os
import random
from functools import partial
import zipfile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

import torch
import torch.nn as nn
import torch.optim as optim


from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence

from TST import TSTransformerEncoderClassiregressor, NoFussCrossEntropyLoss
from tqdm.auto import tqdm

## Model

In [None]:
# Hyper-parameters handed positionally, in this order, to
# TSTransformerEncoderClassiregressor when the model is built.

# Input / sequence sizes
feat_dim = 10
max_len = 120

# Encoder sizes
d_model = 32
n_heads = 2
num_layers = 2
dim_feedforward = 32

# Output size
num_classes = 3

# Regularisation and architecture switches
dropout = 0.1
pos_encoding = "learn"
activation = "relu"
norm = "BatchNorm"
freeze = False

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU so that
# `device` is always defined for the `.to(device)` calls that follow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the transformer classifier from the hyper-parameters (passed
# positionally, same order as before), then move it onto `device`.
class_model = TSTransformerEncoderClassiregressor(
    feat_dim,
    max_len,
    d_model,
    n_heads,
    num_layers,
    dim_feedforward,
    num_classes,
    dropout,
    pos_encoding,
    activation,
    norm,
    freeze,
).to(device)

## Data

In [None]:
class Trajectory(Dataset):
    """Dataset of variable-length trajectory segments with padding masks.

    Indexing returns the tuple ``(traj, label, index, length, mask)`` for one
    segment, where ``mask`` is a boolean vector of width ``max_len`` that is
    True on the first ``length`` positions and False on the padded tail.
    """

    def __init__(self, traj, label, index, max_len):
        """
        traj: features, one array/tensor per segment (``shape[0]`` is its length)
        label: class of each segment (trawlers, gillnetters or seiners)
        index: which trajectory each segment belongs to
        max_len: max length of segments (width of the padding mask)
        """
        super().__init__()
        self.traj = traj
        self.label = torch.tensor(label)
        self.index = index
        # Number of time steps in every segment, taken from the first dimension.
        self.len = torch.tensor([t.shape[0] for t in traj])
        self.max_len = max_len
        self.make_padding()

    def __getitem__(self, index):
        """Return ``(traj, label, index, length, mask)`` for segment ``index``."""
        return (
            self.traj[index],
            self.label[index],
            self.index[index],
            self.len[index],
            self.mask[index],
        )

    def __len__(self):
        """Return the number of segments (the length of ``index``)."""
        return len(self.index)

    def make_padding(self):
        """Build ``self.mask``: one boolean row of width ``max_len`` per segment.

        Position ``j`` of row ``i`` is True exactly when ``j < self.len[i]``,
        i.e. it marks real (non-padded) time steps.
        """
        # Broadcasting (1, max_len) against (num_segments, 1) fills every row
        # at once instead of looping over the segments in Python.
        positions = torch.arange(self.max_len).unsqueeze(0)
        self.mask = positions < self.len.unsqueeze(1)

In [None]:
def collate_fn(batch, max_len):
    """Collate dataset items into one batch zero-padded to ``max_len`` steps.

    batch: iterable of ``(traj, label, index, length, mask)`` tuples, where
        ``traj`` is a tensor whose first dimension is the sequence length.
    max_len: number of time steps every sequence is padded to.

    Returns ``(padded_sequences, labels, index, lengths, masks)``:
    ``padded_sequences`` has ``max_len`` steps along dimension 1, ``labels``
    and ``masks`` are stacked tensors, ``index`` and ``lengths`` are tuples in
    batch order.

    Raises ValueError if a sequence in the batch is longer than ``max_len``.
    """
    traj, label, index, ll, mask = zip(*batch)
    # pad_sequence only pads up to the longest sequence in *this* batch.
    padded_sequences = pad_sequence(traj, batch_first=True)
    extra_padding = max_len - padded_sequences.size(1)
    if extra_padding < 0:
        raise ValueError(
            "sequence of length %d exceeds max_len=%d"
            % (padded_sequences.size(1), max_len)
        )
    if extra_padding > 0:
        # new_zeros keeps the dtype *and* device of the sequences, so the
        # concatenation also works when the data does not live on the CPU.
        padding = padded_sequences.new_zeros(
            (padded_sequences.size(0), extra_padding, padded_sequences.size(-1))
        )
        padded_sequences = torch.cat([padded_sequences, padding], dim=1)
    return padded_sequences, torch.stack(label), index, ll, torch.stack(mask)

# max_len: max length of segments. Reuse the `max_len` hyper-parameter (120)
# instead of repeating the literal, so the padded batch width cannot drift
# from the value the model was built with.
collate_fn = partial(collate_fn, max_len=max_len)

# Shuffled loader over the `trajectory` dataset; every batch is padded to
# `max_len` time steps by `collate_fn`.
Segmentations = DataLoader(
    trajectory,
    batch_size=5000,
    shuffle=True,
    collate_fn=collate_fn,
)

## Training

In [None]:
# Classification loss applied to the model output and the segment labels.
class_loss = NoFussCrossEntropyLoss()

# Adam over all parameters of the classifier, with learning rate `lr`.
lr = 1e-3
optimizer = optim.Adam(params=class_model.parameters(), lr=lr)

In [None]:
# training
def train(model):
    model.train()
    train_loss = 0
    pbar = tqdm(enumerate(Segmentations))
    for batch_idx, data in pbar:
        optimizer.zero_grad()
        traj, label, _, ll, mask = data
        traj = traj.to(device)
        mask = mask.to(device)
        label = label.to(device)
        output = class_model(traj, mask)
        loss = class_loss(output, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f ' %
            (batch_idx + 1, len(Segmentations), train_loss/(batch_idx+1))
        )
    return train_loss/(batch_idx+1)

In [None]:
pbar = tqdm(range(epochs))
for epoch in pbar:
    pbar.set_description('Epoch: %d' % (epoch))