In [None]:
"""Notebook implementing Transformer model trained on MIDI data."""

In [1]:
!unzip /content/DATASET_AUGMENTED.zip -d /content/DATASET_AUGMENTED/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73973.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73973_aug1_shift+6.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73973_aug2_shift+2.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73973_aug3_shift-7.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73981.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73981_aug1_shift-5.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73981_aug2_shift+2.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73981_aug3_shift+6.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73982.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73982_aug1_shift-4.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/73982_aug2_shift-1.mid  
  inflating: /content/DATASET_AUGMENTED/DATASET_AUGMENTED/739

In [2]:
!pip install -qq torch midi_neural_processor pretty_midi tqdm utils

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━[0m [32m4.2/5.6 MB[0m [31m133.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m90.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m124.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m91.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m42.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
!pip install mlflow

Collecting mlflow
  Downloading mlflow-2.22.0-py3-none-any.whl.metadata (30 kB)
Collecting mlflow-skinny==2.22.0 (from mlflow)
  Downloading mlflow_skinny-2.22.0-py3-none-any.whl.metadata (31 kB)
Collecting alembic!=1.10.0,<2 (from mlflow)
  Downloading alembic-1.15.2-py3-none-any.whl.metadata (7.3 kB)
Collecting docker<8,>=4.0.0 (from mlflow)
  Downloading docker-7.1.0-py3-none-any.whl.metadata (3.8 kB)
Collecting graphene<4 (from mlflow)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==2.22.0->mlflow)
  Downloading databricks_sdk-0.52.0-py3-none-any.whl.metadata (39 kB)
Collecting fastapi<1 (from mlflow-skinny==2.22.0->mlflow)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn<1 (from mlflow-skinny==2.22.0->mlflow)
  Downloading uvicorn-0.34.2-py3-none-any.whl.metadata (6.5 k

In [None]:
import os
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from tqdm import tqdm
import mlflow

import midi_neural_processor.processor as midi_tokenizer

In [None]:
EVENT_RANGE = midi_tokenizer.RANGE_NOTE_ON
EVENT_RANGE += midi_tokenizer.RANGE_NOTE_OFF
EVENT_RANGE += midi_tokenizer.RANGE_TIME_SHIFT
EVENT_RANGE += midi_tokenizer.RANGE_VEL

CONFIG = {
    "max_sequence_len": 1024,
    "block_size": 512,
    "embedding_dim": 256,
    "num_heads": 8,
    "num_layers": 6,
    "batch_size": 16,
    "token_pad": EVENT_RANGE,
    "token_sos": EVENT_RANGE + 1,
    "token_eos": EVENT_RANGE + 2,
    "vocab_size": EVENT_RANGE + 3,
    "seed": 42,
    "model_out": "midi_transformer_model.pt"
}

In [None]:
DATASET_DIR = "/content/DATASET_AUGMENTED/DATASET_AUGMENTED"

all_midis = [f for f in os.listdir(DATASET_DIR) if f.endswith(".mid")]
random.shuffle(all_midis)

total_midis = len(all_midis)
midis_train = int(0.9 * total_midis)
midis_validation = int(0.1 * total_midis)

train_files = all_midis[:midis_train]
validation_files = all_midis[midis_train:]

In [7]:
print(len(train_files))
print(len(validation_files))
print(len(all_midis))

31730
3526
35256


In [None]:
list_of_seqs = []
MAX_SEQ_LEN = CONFIG['max_sequence_len']
MAX_SEQ_LEN_MUSIC_ONLY = MAX_SEQ_LEN - 2

print(f"Encoding {total_midis} MIDI files into sequences.")

for f in tqdm(all_midis, desc="Encoding MIDI files into sequences"):
    tokens = midi_tokenizer.encode_midi(os.path.join(DATASET_DIR, f))

    for i in range(0, len(tokens), MAX_SEQ_LEN_MUSIC_ONLY):
        chunk = tokens[i : i + MAX_SEQ_LEN_MUSIC_ONLY]

        if len(chunk) < MAX_SEQ_LEN_MUSIC_ONLY:
            pad_count = MAX_SEQ_LEN_MUSIC_ONLY - len(chunk)
            chunk = chunk + [CONFIG['token_pad']] * pad_count

        seq = [CONFIG['token_sos']] + chunk + [CONFIG['token_eos']]
        assert len(seq) == MAX_SEQ_LEN

        list_of_seqs.append(seq)

print(f"Total token sequences: {len(list_of_seqs)}")

Encoding 35256 MIDI files into sequences.


Encoding MIDI files into sequences: 100%|██████████| 35256/35256 [03:35<00:00, 163.87it/s]

Total token sequences: 47422





In [None]:
NUM_SEQUENCES = len(list_of_seqs)
print(NUM_SEQUENCES)

train_seqs = list_of_seqs[:int(0.9 * NUM_SEQUENCES)]
val_seqs = list_of_seqs[int(0.9 * NUM_SEQUENCES):]

print(len(train_seqs))
print(len(val_seqs))

47422
42679
4743


In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
train_data = torch.tensor(train_seqs, dtype=torch.long, device=DEVICE)
print(train_data.shape, train_data.dtype)
print(train_data[:100])

test_data = torch.tensor(val_seqs, dtype=torch.long, device=DEVICE)
print(test_data.shape, test_data.dtype)
print(test_data[:100])

torch.Size([42679, 1024]) torch.int64
tensor([[389, 295, 376,  ..., 388, 388, 390],
        [389, 355, 295,  ..., 198, 387, 390],
        [389,  72, 264,  ..., 256, 387, 390],
        ...,
        [389, 271, 178,  ..., 388, 388, 390],
        [389, 355, 355,  ..., 388, 388, 390],
        [389, 295, 376,  ..., 388, 388, 390]], device='cuda:0')
torch.Size([4743, 1024]) torch.int64
tensor([[389, 315, 376,  ..., 388, 388, 390],
        [389, 376,  60,  ..., 388, 388, 390],
        [389, 376,  74,  ..., 376,  65, 390],
        ...,
        [389, 376,  67,  ..., 388, 388, 390],
        [389, 267, 376,  ..., 388, 388, 390],
        [389, 376,  54,  ..., 388, 388, 390]], device='cuda:0')


In [12]:
X_train = train_data[:, :-1]
Y_train = train_data[:, 1:]

X_val = test_data[:, :-1]
Y_val = test_data[:, 1:]

In [13]:
train_ds = TensorDataset(X_train, Y_train)
val_ds   = TensorDataset(X_val,   Y_val)

In [None]:
train_loader = DataLoader(
    train_ds,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    drop_last=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    drop_last=True
)

In [None]:
class Head(nn.Module):
    """
    Individual head responsible for applying the attention mechanism and
    relative positional embeddings.
    """

    def __init__(self, head_size, n_embd, block_size, dropout=0.1, relative_pos=True):
        """
        Initialize the AttentionHead.

        Args:
            head_size (int): The size of the attention head.
            embedding_dim (int): The input embedding dimension.
            block_size (int): The maximum sequence length supported.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
            use_relative_position (bool, optional): Whether to apply relative positional embeddings.
        """

        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.block_size = block_size
        self.relative_pos = relative_pos

        if relative_pos:
            # relative_embeddings: [block_size, head_size]
            self.relative_embeddings = nn.Parameter(torch.randn(CONFIG['max_sequence_len'], head_size))
        else:
            self.relative_embeddings = None

    def _skew(self, x):
        """
        Apply skewing operation for relative positional attention.

        Args:
            attention_scores (torch.Tensor): The attention scores tensor of shape (batch_size, seq_len, seq_len).

        Returns:
            torch.Tensor: Skewed attention scores tensor.
        """

        B, T1, T2 = x.size()
        x = F.pad(x, (1, 0))              # (B, T, T+1)
        x = x.view(B, T2 + 1, T1)
        x = x[:, 1:, :]
        return x.transpose(1, 2)


    def forward(self, x):

        """
        Forward pass for computing attention outputs.

        Args:
            input_sequence (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, head_size).
        """

        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Compute attention scores
        att = q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5)  # (B, T, T)

        if self.relative_pos:
            # (head_size, T)
            relative_embeddings_transposed = self.relative_embeddings[:T, :].transpose(0, 1)
            rel_att = q @ relative_embeddings_transposed  # (B, T, T)
            rel_att = self._skew(rel_att)  # Skew to match shape
            att = att + rel_att

        # Apply causal mask
        mask = torch.tril(torch.ones(T, T, device=x.device))
        att = att.masked_fill(mask == 0, float("-inf"))

        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        out = att @ v  # (B, T, head_size)
        return out


In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention mechanism with optional relative positional embeddings.

    Combines multiple independent attention heads and projects the result back to the
    original embedding space.
    """

    def __init__(self, num_heads, n_embd, block_size, dropout=0.1, relative_pos=True):
        """
        Initialize the MultiHeadAttention module.

        Args:
            num_heads (int): Number of attention heads.
            embedding_dim (int): Dimension of the input embeddings.
            block_size (int): Maximum sequence length supported.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
            use_relative_position (bool, optional): Whether to apply relative positional embeddings. Defaults to True.
        """

        super().__init__()
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([
            Head(head_size, n_embd, block_size, dropout, relative_pos=relative_pos)
            for _ in range(num_heads)
        ])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass to compute multi-head attention outputs.

        Args:
            input_sequence (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, embedding_dim).
        """
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

In [None]:
class FeedForward(nn.Module):
    """
    Feed-forward neural network with activation and dropout.

    This is typically used as the second component of a Transformer block.
    """

    def __init__(self, n_embd, dropout=0.1):
        """
        Initialize the FeedForward module.

        Args:
            embedding_dim (int): Dimension of the input and output embeddings.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """
        Forward pass through the feed-forward network.

        Args:
            input_sequence (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape.
        """
        return self.net(x)


class Block(nn.Module):
    """
    Single Transformer block consisting of multi-head attention,
    feed-forward network, and layer normalization.
    """

    def __init__(self, n_embd, n_head, block_size, dropout=0.1):
        """
        Initialize the TransformerBlock.

        Args:
            embedding_dim (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            block_size (int): Maximum sequence length.
            dropout (float, optional): Dropout probability. Defaults to 0.1.
        """
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """
        Forward pass through the Transformer block.

        Args:
            input_sequence (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim).

        Returns:
            torch.Tensor: Output tensor of the same shape.
        """
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [None]:
"""Model class, processes input and target token sequences through
stacked transformer decoder blocks"""

class TransformerMIDILanguageModel(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.block_size    = config['block_size']
        self.vocab_size    = config['vocab_size']
        self.sos_token_id  = config['token_sos']
        self.pad_token_id  = config['token_pad']
        self.eos_token_id  = config['token_eos']

        self.token_embedding_table = nn.Embedding(
            self.vocab_size,
            config['embedding_dim'],
            padding_idx=self.pad_token_id
        )

        self.blocks = nn.Sequential(*[
            Block(n_embd=config['embedding_dim'],
                  n_head=config['num_heads'],
                  block_size=config['block_size'],
                  dropout=0.1)
            for _ in range(config['num_layers'])
            ])

        self.ln_f = nn.LayerNorm(config['embedding_dim'])
        self.lm_head = nn.Linear(config['embedding_dim'], self.vocab_size)


    def forward(self, idx, targets=None):
        tok_emb = self.token_embedding_table(idx)

        x = self.blocks(tok_emb)
        x = self.ln_f(x)

        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1),
                                   ignore_index=self.pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, eos_token_id, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)

            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, top_k, dim=-1)
                min_vals = v[:, -1].unsqueeze(1)
                logits[logits < min_vals] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)

            if eos_token_id is not None:
                # check every item in batch
                eos_mask = idx_next.eq(eos_token_id).view(-1)
                if eos_mask.all():
                    break

        return idx

In [None]:
model = TransformerMIDILanguageModel(CONFIG).to(DEVICE)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
print(model)

6.507911 M parameters
TransformerMIDILanguageModel(
  (token_embedding_table): Embedding(391, 256, padding_idx=388)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=256, out_features=32, bias=False)
            (query): Linear(in_features=256, out_features=32, bias=False)
            (value): Linear(in_features=256, out_features=32, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1024, out_features=256, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ln1): LayerNorm((256,), eps=1e-05

In [20]:
@torch.no_grad()
def estimate_loss(model, val_loader, eval_iters, device):
    model.eval()
    losses = []
    for i, (xb, yb) in enumerate(val_loader):
        if i >= eval_iters:
            break
        xb, yb = xb.to(device), yb.to(device)
        _, loss = model(xb, yb)
        losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)

In [None]:
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    config,
    device,
    max_epochs=50,
    eval_iters=200,
    experiment_name="MIDI_Transformer_Exp",
    run_name=None,
    start_epoch=0
):

    patience = 7

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run(run_name=run_name):
        # Log all config parameters
        for k, v in config.items():
            mlflow.log_param(k, v)

        mlflow.log_param("optimizer", "AdamW")
        mlflow.log_param("learning_rate", optimizer.param_groups[0]['lr'])

        best_val_loss = float("inf")
        train_loss_log = []
        val_loss_log = []
        early_stop_counter = 0

        for epoch in range(start_epoch, start_epoch + max_epochs):
            model.train()
            running_loss = 0.0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{start_epoch + max_epochs}")

            for xb, yb in pbar:
                xb, yb = xb.to(device), yb.to(device)
                optimizer.zero_grad(set_to_none=True)
                _, loss = model(xb, yb)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                pbar.set_postfix(train_loss=running_loss / (pbar.n + 1))

            avg_train_loss = running_loss / len(train_loader)
            train_loss_log.append(avg_train_loss)
            mlflow.log_metric("train_loss", avg_train_loss, step=epoch)

            # Evaluate on val set
            avg_val_loss = estimate_loss(model, val_loader, eval_iters, device)
            val_loss_log.append(avg_val_loss)
            mlflow.log_metric("val_loss", avg_val_loss, step=epoch)

            print(
                f"Epoch {epoch+1}: train loss = {avg_train_loss:.4f}, val loss = {avg_val_loss:.4f}"
                )

            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = f"best_model_{epoch}.pth"
                torch.save(model.state_dict(), best_model_path)
                mlflow.log_artifact(best_model_path)
                print(f"Best model saved at epoch {epoch+1} with val loss {best_val_loss:.4f}")

                best_optimizer_path = f"best_optimizer_{epoch+1}.pth"
                torch.save(optimizer.state_dict(), best_optimizer_path)
                mlflow.log_artifact(best_optimizer_path)
            else:
                early_stop_counter += 1
                print(f"No improvement for {early_stop_counter} epoch.")

            if early_stop_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}.")
                break

        return model, train_loss_log, val_loss_log

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

trained_model, train_log, val_log = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    config=CONFIG,
    device=DEVICE,
    max_epochs=50,
    eval_iters=200,
    experiment_name="MIDI_Transformer_Training",
    run_name="Run_01"
)

Epoch 1/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.919]


Epoch 1: train loss = 0.9186, val loss = 0.8663
Best model saved at epoch 1 with val loss 0.8663


Epoch 2/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.866]


Epoch 2: train loss = 0.8659, val loss = 0.8345
Best model saved at epoch 2 with val loss 0.8345


Epoch 3/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.835]


Epoch 3: train loss = 0.8345, val loss = 0.8044
Best model saved at epoch 3 with val loss 0.8044


Epoch 4/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.804]


Epoch 4: train loss = 0.8039, val loss = 0.7756
Best model saved at epoch 4 with val loss 0.7756


Epoch 5/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.778]


Epoch 5: train loss = 0.7777, val loss = 0.7529
Best model saved at epoch 5 with val loss 0.7529


Epoch 6/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.757]


Epoch 6: train loss = 0.7565, val loss = 0.7343
Best model saved at epoch 6 with val loss 0.7343


Epoch 7/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.739]


Epoch 7: train loss = 0.7390, val loss = 0.7204
Best model saved at epoch 7 with val loss 0.7204


Epoch 8/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.725]


Epoch 8: train loss = 0.7246, val loss = 0.7102
Best model saved at epoch 8 with val loss 0.7102


Epoch 9/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.712]


Epoch 9: train loss = 0.7122, val loss = 0.7000
Best model saved at epoch 9 with val loss 0.7000


Epoch 10/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.701]


Epoch 10: train loss = 0.7010, val loss = 0.6890
Best model saved at epoch 10 with val loss 0.6890


Epoch 11/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.668]


Epoch 11: train loss = 0.6677, val loss = 0.6193
Best model saved at epoch 11 with val loss 0.6193


Epoch 12/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.63]


Epoch 12: train loss = 0.6302, val loss = 0.5947
Best model saved at epoch 12 with val loss 0.5947


Epoch 13/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.608]


Epoch 13: train loss = 0.6084, val loss = 0.5793
Best model saved at epoch 13 with val loss 0.5793


Epoch 14/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.593]


Epoch 14: train loss = 0.5931, val loss = 0.5708
Best model saved at epoch 14 with val loss 0.5708


Epoch 15/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.582]


Epoch 15: train loss = 0.5819, val loss = 0.5638
Best model saved at epoch 15 with val loss 0.5638


Epoch 16/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.574]


Epoch 16: train loss = 0.5739, val loss = 0.5584
Best model saved at epoch 16 with val loss 0.5584


Epoch 17/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.566]


Epoch 17: train loss = 0.5662, val loss = 0.5560
Best model saved at epoch 17 with val loss 0.5560


Epoch 18/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.56]


Epoch 18: train loss = 0.5595, val loss = 0.5478
Best model saved at epoch 18 with val loss 0.5478


Epoch 19/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.553]


Epoch 19: train loss = 0.5534, val loss = 0.5430
Best model saved at epoch 19 with val loss 0.5430


Epoch 20/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.548]


Epoch 20: train loss = 0.5482, val loss = 0.5391
Best model saved at epoch 20 with val loss 0.5391


Epoch 21/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.542]


Epoch 21: train loss = 0.5425, val loss = 0.5380
Best model saved at epoch 21 with val loss 0.5380


Epoch 22/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.538]


Epoch 22: train loss = 0.5380, val loss = 0.5331
Best model saved at epoch 22 with val loss 0.5331


Epoch 23/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.533]


Epoch 23: train loss = 0.5331, val loss = 0.5293
Best model saved at epoch 23 with val loss 0.5293


Epoch 24/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.529]


Epoch 24: train loss = 0.5289, val loss = 0.5247
Best model saved at epoch 24 with val loss 0.5247


Epoch 25/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.525]


Epoch 25: train loss = 0.5246, val loss = 0.5228
Best model saved at epoch 25 with val loss 0.5228


Epoch 26/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.521]


Epoch 26: train loss = 0.5208, val loss = 0.5187
Best model saved at epoch 26 with val loss 0.5187


Epoch 27/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.517]


Epoch 27: train loss = 0.5171, val loss = 0.5165
Best model saved at epoch 27 with val loss 0.5165


Epoch 28/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.514]


Epoch 28: train loss = 0.5137, val loss = 0.5151
Best model saved at epoch 28 with val loss 0.5151


Epoch 29/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.51]


Epoch 29: train loss = 0.5102, val loss = 0.5112
Best model saved at epoch 29 with val loss 0.5112


Epoch 30/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.507]


Epoch 30: train loss = 0.5071, val loss = 0.5101
Best model saved at epoch 30 with val loss 0.5101


Epoch 31/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.504]


Epoch 31: train loss = 0.5040, val loss = 0.5074
Best model saved at epoch 31 with val loss 0.5074


Epoch 32/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.501]


Epoch 32: train loss = 0.5014, val loss = 0.5051
Best model saved at epoch 32 with val loss 0.5051


Epoch 33/50: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.498]


Epoch 33: train loss = 0.4983, val loss = 0.5036
Best model saved at epoch 33 with val loss 0.5036


Epoch 34/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.496]


Epoch 34: train loss = 0.4959, val loss = 0.5037
No improvement for 1 epoch.


Epoch 35/50: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.493]


Epoch 35: train loss = 0.4930, val loss = 0.5001
Best model saved at epoch 35 with val loss 0.5001


Epoch 36/50:  83%|████████▎ | 2223/2667 [09:33<01:54,  3.88it/s, train_loss=0.49]

Additional training on 50 epochs model

In [None]:
model_trained_on_50_epochs = TransformerMIDILanguageModel(CONFIG).to(DEVICE)
model_trained_on_50_epochs.load_state_dict(torch.load("best_model_49.pth"))
model_trained_on_50_epochs.eval()
optimizer_additional_50_epochs = torch.optim.Adam(model_trained_on_50_epochs.parameters(), lr=1e-4)
print(model_trained_on_50_epochs)

TransformerMIDILanguageModel(
  (token_embedding_table): Embedding(391, 256, padding_idx=388)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=256, out_features=32, bias=False)
            (query): Linear(in_features=256, out_features=32, bias=False)
            (value): Linear(in_features=256, out_features=32, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1024, out_features=256, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=T

In [None]:
trained_model, train_log, val_log = train_model(
    model=model_trained_on_50_epochs,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer_additional_50_epochs,
    config=CONFIG,
    device=DEVICE,
    max_epochs=100,
    eval_iters=200,
    experiment_name="MIDI_Transformer_Training",
    run_name="Run_Resume",
    start_epoch=50
)

2025/05/10 21:15:32 INFO mlflow.tracking.fluent: Experiment with name 'MIDI_Transformer_Training' does not exist. Creating a new experiment.
Epoch 51/150: 100%|██████████| 2667/2667 [11:29<00:00,  3.87it/s, train_loss=1.37]


Epoch 51: train loss = 1.3693, val loss = 0.9553
Best model saved at epoch 51 with val loss 0.9553


Epoch 52/150: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.916]


Epoch 52: train loss = 0.9158, val loss = 0.8827
Best model saved at epoch 52 with val loss 0.8827


Epoch 53/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.865]


Epoch 53: train loss = 0.8647, val loss = 0.8518
Best model saved at epoch 53 with val loss 0.8518


Epoch 54/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.834]


Epoch 54: train loss = 0.8340, val loss = 0.8231
Best model saved at epoch 54 with val loss 0.8231


Epoch 55/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.804]


Epoch 55: train loss = 0.8038, val loss = 0.7943
Best model saved at epoch 55 with val loss 0.7943


Epoch 56/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.778]


Epoch 56: train loss = 0.7775, val loss = 0.7695
Best model saved at epoch 56 with val loss 0.7695


Epoch 57/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.757]


Epoch 57: train loss = 0.7570, val loss = 0.7501
Best model saved at epoch 57 with val loss 0.7501


Epoch 58/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.739]


Epoch 58: train loss = 0.7391, val loss = 0.7357
Best model saved at epoch 58 with val loss 0.7357


Epoch 59/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.723]


Epoch 59: train loss = 0.7233, val loss = 0.7207
Best model saved at epoch 59 with val loss 0.7207


Epoch 60/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.702]


Epoch 60: train loss = 0.7015, val loss = 0.6716
Best model saved at epoch 60 with val loss 0.6716


Epoch 61/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.657]


Epoch 61: train loss = 0.6569, val loss = 0.6288
Best model saved at epoch 61 with val loss 0.6288


Epoch 62/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.628]


Epoch 62: train loss = 0.6275, val loss = 0.6086
Best model saved at epoch 62 with val loss 0.6086


Epoch 63/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.609]


Epoch 63: train loss = 0.6089, val loss = 0.5958
Best model saved at epoch 63 with val loss 0.5958


Epoch 64/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.595]


Epoch 64: train loss = 0.5949, val loss = 0.5863
Best model saved at epoch 64 with val loss 0.5863


Epoch 65/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.584]


Epoch 65: train loss = 0.5839, val loss = 0.5815
Best model saved at epoch 65 with val loss 0.5815


Epoch 66/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.575]


Epoch 66: train loss = 0.5750, val loss = 0.5742
Best model saved at epoch 66 with val loss 0.5742


Epoch 67/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.568]


Epoch 67: train loss = 0.5677, val loss = 0.5679
Best model saved at epoch 67 with val loss 0.5679


Epoch 68/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.561]


Epoch 68: train loss = 0.5609, val loss = 0.5646
Best model saved at epoch 68 with val loss 0.5646


Epoch 69/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.555]


Epoch 69: train loss = 0.5548, val loss = 0.5589
Best model saved at epoch 69 with val loss 0.5589


Epoch 70/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.549]


Epoch 70: train loss = 0.5492, val loss = 0.5561
Best model saved at epoch 70 with val loss 0.5561


Epoch 71/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.544]


Epoch 71: train loss = 0.5440, val loss = 0.5526
Best model saved at epoch 71 with val loss 0.5526


Epoch 72/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.539]


Epoch 72: train loss = 0.5390, val loss = 0.5493
Best model saved at epoch 72 with val loss 0.5493


Epoch 73/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.535]


Epoch 73: train loss = 0.5346, val loss = 0.5466
Best model saved at epoch 73 with val loss 0.5466


Epoch 74/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.531]


Epoch 74: train loss = 0.5311, val loss = 0.5437
Best model saved at epoch 74 with val loss 0.5437


Epoch 75/150: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.527]


Epoch 75: train loss = 0.5270, val loss = 0.5405
Best model saved at epoch 75 with val loss 0.5405


Epoch 76/150: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.523]


Epoch 76: train loss = 0.5230, val loss = 0.5384
Best model saved at epoch 76 with val loss 0.5384


Epoch 77/150: 100%|██████████| 2667/2667 [11:27<00:00,  3.88it/s, train_loss=0.52]


Epoch 77: train loss = 0.5198, val loss = 0.5363
Best model saved at epoch 77 with val loss 0.5363


Epoch 78/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.516]


Epoch 78: train loss = 0.5164, val loss = 0.5353
Best model saved at epoch 78 with val loss 0.5353


Epoch 79/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.514]


Epoch 79: train loss = 0.5135, val loss = 0.5324
Best model saved at epoch 79 with val loss 0.5324


Epoch 80/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.51]


Epoch 80: train loss = 0.5101, val loss = 0.5293
Best model saved at epoch 80 with val loss 0.5293


Epoch 81/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.507]


Epoch 81: train loss = 0.5070, val loss = 0.5290
Best model saved at epoch 81 with val loss 0.5290


Epoch 82/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.505]


Epoch 82: train loss = 0.5046, val loss = 0.5242
Best model saved at epoch 82 with val loss 0.5242


Epoch 83/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.502]


Epoch 83: train loss = 0.5016, val loss = 0.5238
Best model saved at epoch 83 with val loss 0.5238


Epoch 84/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.499]


Epoch 84: train loss = 0.4992, val loss = 0.5217
Best model saved at epoch 84 with val loss 0.5217


Epoch 85/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.497]


Epoch 85: train loss = 0.4967, val loss = 0.5203
Best model saved at epoch 85 with val loss 0.5203


Epoch 86/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.494]


Epoch 86: train loss = 0.4940, val loss = 0.5187
Best model saved at epoch 86 with val loss 0.5187


Epoch 87/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.492]


Epoch 87: train loss = 0.4918, val loss = 0.5178
Best model saved at epoch 87 with val loss 0.5178


Epoch 88/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.49]


Epoch 88: train loss = 0.4898, val loss = 0.5151
Best model saved at epoch 88 with val loss 0.5151


Epoch 89/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.488]


Epoch 89: train loss = 0.4876, val loss = 0.5149
Best model saved at epoch 89 with val loss 0.5149


Epoch 90/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.486]


Epoch 90: train loss = 0.4858, val loss = 0.5137
Best model saved at epoch 90 with val loss 0.5137


Epoch 91/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.483]


Epoch 91: train loss = 0.4834, val loss = 0.5110
Best model saved at epoch 91 with val loss 0.5110


Epoch 92/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.88it/s, train_loss=0.481]


Epoch 92: train loss = 0.4812, val loss = 0.5126
No improvement for 1 epoch.


Epoch 93/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.479]


Epoch 93: train loss = 0.4795, val loss = 0.5098
Best model saved at epoch 93 with val loss 0.5098


Epoch 94/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.478]


Epoch 94: train loss = 0.4778, val loss = 0.5078
Best model saved at epoch 94 with val loss 0.5078


Epoch 95/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.476]


Epoch 95: train loss = 0.4763, val loss = 0.5077
Best model saved at epoch 95 with val loss 0.5077


Epoch 96/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.474]


Epoch 96: train loss = 0.4739, val loss = 0.5066
Best model saved at epoch 96 with val loss 0.5066


Epoch 97/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.472]


Epoch 97: train loss = 0.4723, val loss = 0.5062
Best model saved at epoch 97 with val loss 0.5062


Epoch 98/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.471]


Epoch 98: train loss = 0.4710, val loss = 0.5055
Best model saved at epoch 98 with val loss 0.5055


Epoch 99/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.469]


Epoch 99: train loss = 0.4693, val loss = 0.5058
No improvement for 2 epoch.


Epoch 100/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.468]


Epoch 100: train loss = 0.4676, val loss = 0.5040
Best model saved at epoch 100 with val loss 0.5040


Epoch 101/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.466]


Epoch 101: train loss = 0.4664, val loss = 0.5022
Best model saved at epoch 101 with val loss 0.5022


Epoch 102/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.465]


Epoch 102: train loss = 0.4647, val loss = 0.5009
Best model saved at epoch 102 with val loss 0.5009


Epoch 103/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.463]


Epoch 103: train loss = 0.4628, val loss = 0.5005
Best model saved at epoch 103 with val loss 0.5005


Epoch 104/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.462]


Epoch 104: train loss = 0.4618, val loss = 0.5005
Best model saved at epoch 104 with val loss 0.5005


Epoch 105/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.46]


Epoch 105: train loss = 0.4601, val loss = 0.4995
Best model saved at epoch 105 with val loss 0.4995


Epoch 106/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.459]


Epoch 106: train loss = 0.4590, val loss = 0.4999
No improvement for 3 epoch.


Epoch 107/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.457]


Epoch 107: train loss = 0.4574, val loss = 0.4985
Best model saved at epoch 107 with val loss 0.4985


Epoch 108/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.456]


Epoch 108: train loss = 0.4564, val loss = 0.4984
Best model saved at epoch 108 with val loss 0.4984


Epoch 109/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.455]


Epoch 109: train loss = 0.4549, val loss = 0.4973
Best model saved at epoch 109 with val loss 0.4973


Epoch 110/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.454]


Epoch 110: train loss = 0.4538, val loss = 0.4948
Best model saved at epoch 110 with val loss 0.4948


Epoch 111/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.452]


Epoch 111: train loss = 0.4523, val loss = 0.4965
No improvement for 4 epoch.


Epoch 112/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.451]


Epoch 112: train loss = 0.4514, val loss = 0.4945
Best model saved at epoch 112 with val loss 0.4945


Epoch 113/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.45]


Epoch 113: train loss = 0.4503, val loss = 0.4950
No improvement for 5 epoch.


Epoch 114/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.449]


Epoch 114: train loss = 0.4493, val loss = 0.4931
Best model saved at epoch 114 with val loss 0.4931


Epoch 115/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.448]


Epoch 115: train loss = 0.4479, val loss = 0.4924
Best model saved at epoch 115 with val loss 0.4924


Epoch 116/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.447]


Epoch 116: train loss = 0.4468, val loss = 0.4909
Best model saved at epoch 116 with val loss 0.4909


Epoch 117/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.446]


Epoch 117: train loss = 0.4459, val loss = 0.4921
No improvement for 6 epoch.


Epoch 118/150: 100%|██████████| 2667/2667 [11:28<00:00,  3.87it/s, train_loss=0.445]


Epoch 118: train loss = 0.4446, val loss = 0.4915
No improvement for 7 epoch.
Early stopping at epoch 118.
