In [5]:
import os
import re
import random
from enum import Enum
from typing import List, Tuple, DefaultDict, Set
from collections import defaultdict, Counter

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, GPT2LMHeadModel, AdamW
from tqdm.auto import tqdm
from Bio import SeqIO
import wandb


os.environ["WANDB_NOTEBOOK_NAME"] = 'GPT_dBG.ipynb'


class DatasetType(Enum):
    HMM = "HMM"
    FASTA = "fasta"



class DatasetConfig:
    def __init__(self, dataset_type, sequence_length: int, num_seqs: int, split_ratio: float):
        self.dataset = dataset_type
        self.sequence_length = sequence_length
        self.num_seqs = num_seqs
        self.split_ratio = split_ratio


class FastaDatasetConfig(DatasetConfig):
    def __init__(self, sequence_length: int, num_seqs: int, file_path: str,
                 stride: int, substrings_per_seq: int, sequences_shuffle: bool = True, split_ratio: float = 0.5):
        super().__init__(DatasetType.FASTA, sequence_length, num_seqs, split_ratio)
        self.file_path = file_path
        self.stride = stride
        self.substrings_per_seq = substrings_per_seq
        self.sequences_shuffle = sequences_shuffle


class HMMDatasetConfig(DatasetConfig):
    def __init__(self, sequence_length: int, num_seqs: int, sparsity: float, num_hidden_states: int,  split_ratio: float = 0.5):
        super().__init__(DatasetType.HMM, sequence_length, num_seqs, split_ratio)
        self.sparsity = sparsity
        self.num_hidden_states = num_hidden_states


class ModelConfig:
    def __init__(self, model_name: str, n_embed: int, n_layer: int, n_head: int):
        self.model_name = model_name
        self.n_embed = n_embed
        self.n_layer = n_layer
        self.n_head = n_head


class TrainingConfig:
    def __init__(self, train_bs: int, val_bs: int, lr: float, weight_decay: float,
                 num_epochs: int, early_stopping_patience: int, print_every: int):
        self.train_bs = train_bs
        self.val_bs = val_bs
        self.lr = lr
        self.weight_decay = weight_decay
        self.num_epochs = num_epochs
        self.early_stopping_patience = early_stopping_patience
        self.print_every = print_every


class Config:
    def __init__(self, dataset_config: DatasetConfig, model_config: ModelConfig, training_config: TrainingConfig):
        self.dataset_config = dataset_config
        self.model_config = model_config
        self.training_config = training_config


class SequenceTokenizer:
    def __init__(self):
        self.alphabet = {'A', 'C', 'G', 'T'}
        self.token_to_idx = {char: i for i, char in enumerate(self.alphabet)}
        self.idx_to_token = {i: char for i, char in enumerate(self.alphabet)}
        self.vocab_size = len(self.token_to_idx)

    def encode(self, sequence: str, return_tensors: str = "pt") -> torch.Tensor:
        tokens = [self.token_to_idx[char] for char in sequence]
        if return_tensors == "pt":
            tokens = torch.tensor(tokens, dtype=torch.long)
        return tokens

    def decode(self, tokens: torch.Tensor) -> str:
        sequence = [self.idx_to_token[token.item()] for token in tokens]
        return ''.join(sequence)


class DNADataset(Dataset):
    def __init__(self, sequences: List[str], tokenizer: SequenceTokenizer):
        self.sequences = sequences
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        sequence = self.sequences[i]
        inputs = self.tokenizer.encode(sequence, return_tensors='pt')
        targets = inputs[1:].clone()
        inputs = inputs[:-1]
        return inputs, targets


def compute_char_probabilities(sequences: List[str]) -> None:
    counts = defaultdict(int)
    total_count = 0
    for seq in sequences:
        cd = Counter(seq)
        for c, count in cd.items():
            counts[c] += count
            total_count += count
    pcounts = {char: c / total_count for char, c in counts.items()}
    print(pcounts)


def read_fna(file_path: str, shuffle: bool = False) -> List[str]:
    sequences = []
    with open(file_path, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            sequences.append(str(record.seq))
    if shuffle:
        random.shuffle(sequences)
    return sequences


def generate_markov_chain(n: int, sparsity: float) -> torch.Tensor:
    transition_probs = torch.rand(n, n)
    for i in range(n):
        num_outgoing_states = torch.poisson(torch.tensor([float(sparsity)])).int().item()
        num_outgoing_states = min(n, max(1, num_outgoing_states))
        _, indices = torch.topk(transition_probs[i], num_outgoing_states)
        mask = torch.zeros_like(transition_probs[i]).scatter_(0, indices, 1).to(torch.bool)
        transition_probs[i] *= mask
    transition_probs = F.normalize(transition_probs, p=1, dim=1)
    return transition_probs


def draw_seq(transition_probs: torch.Tensor, sequence_length: int) -> str:
    NUCLEOTIDES = ['A', 'C', 'G', 'T']
    current_state = 0
    chain = [current_state]
    while len(chain) < sequence_length:
        next_state = torch.multinomial(transition_probs[current_state], num_samples=1)
        current_state = next_state.item()
        chain.append(current_state)
    seq = [NUCLEOTIDES[s % len(NUCLEOTIDES)] for s in chain]
    return ''.join(seq)


def generate_HMM_dataset(sequence_length: int, N: int, sparsity: float, num_hidden_states: int = None) -> Tuple[List[str], List[str]]:
    if num_hidden_states is None:
        num_hidden_states = sequence_length
    transition_matrix = generate_markov_chain(num_hidden_states, sparsity=sparsity)
    all_seqs = [[], []]
    for _ in range(2):
        seqs = []
        while len(seqs) < N:
            seq = draw_seq(transition_matrix, sequence_length)
            if len(seq) >= sequence_length:
                seqs.append(seq)
        all_seqs[_] = seqs
    return all_seqs


def generate_phylo_dataset(sequence_length: int, mutation_rate: float, N: int) -> List[str]:
    NUCLEOTIDES = ['A', 'C', 'G', 'T']
    parent_sequence = ''.join(random.choice(NUCLEOTIDES) for _ in range(sequence_length))
    dataset = [parent_sequence]
    for _ in range(N):
        mutated_sequence = ''
        for nucleotide in parent_sequence:
            if random.random() < mutation_rate:
                mutated_sequence += random.choice(NUCLEOTIDES)
            else:
                mutated_sequence += nucleotide
        dataset.append(mutated_sequence)
        parent_sequence = mutated_sequence
    return dataset


def construct_debruijn_graph(dataset: List[str], k: int) -> DefaultDict[str,Set[str]]:
    graph = defaultdict(set)
    for sequence in dataset:
        for i in range(len(sequence) - k):
            graph[sequence[i:i + k]].add(sequence[i + 1:i + k + 1])
    return graph


def extract_substrings(sequences: List[str], sequence_length: int, stride: int, substrings_per_seq: int) -> List[str]:
    substrings = []
    for sequence in sequences:
        for i in range(0, len(sequence) - sequence_length + 1, stride):
            if i // stride > substrings_per_seq:
                break
            seq = sequence[i:i + sequence_length]
            if bool(re.match("^[ACGT]+$", seq)):
                substrings.append(seq)
    return substrings


def set_random_seed(seed: int) -> None:
    torch.manual_seed(seed)
    random.seed(seed)


def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, device: torch.device, description: str) -> float:
    model.train()
    running_loss = []
    bar = tqdm(train_loader, desc=description)
    for inputs, targets in bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs, labels=targets)
        loss = outputs.loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        bar.set_postfix({"Train Loss": loss.item()})
    return sum(running_loss) / len(running_loss)


def evaluate(model: torch.nn.Module, val_loader: DataLoader, device: torch.device, description: str) -> Tuple[float, float]:
    model.eval()
    total_loss = 0.0
    total_count = 0
    total_acc = 0.0
    bar = tqdm(val_loader, desc=description)
    with torch.no_grad():
        for inputs, targets in bar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs, labels=targets)
            val_loss = outputs.loss.mean()
            predictions = outputs.logits.argmax(dim=-1)
            accuracy = (predictions == targets).cpu().float().mean().item()
            total_loss += val_loss.item() * inputs.size(0)
            total_count += inputs.size(0)
            total_acc += accuracy
            bar.set_postfix({"Val Loss": val_loss.item(), "Val Accuracy": accuracy})
    avg_val_loss = total_loss / total_count
    avg_accuracy = total_acc / len(val_loader)
    return avg_val_loss, avg_accuracy


def train_loop(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, config: Config) -> None:
    best_val_loss = float('inf')
    num_epochs_no_improve = 0  # Number of epochs with no improvement in validation loss

    try:
        for epoch in range(config.training_config.num_epochs):
            train_loss = train(model, optimizer, train_loader, device, f"Epoch {epoch + 1}/{config.training_config.num_epochs} | Training")
            val_loss, val_acc = evaluate(model, val_loader, device, f"Epoch {epoch + 1}/{config.training_config.num_epochs} | Validation")
            samples = (epoch+1) * len(train_loader) * config.training_config.train_bs
            print(f"Epoch {epoch + 1}/{config.training_config.num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f} | Val Accuracy: {val_acc:.5f}")
            wandb.log({"epoch": epoch + 1, "samples": samples, "train_loss": train_loss, "val_loss": val_loss, "val_accuracy": val_acc})
            
            if (epoch + 1) % 10 == 0:
                # Save the model weights as an artifact every 10 epochs
                artifact = wandb.Artifact(f"model_weights", type='model')
                torch.save(model.state_dict(), 'gpt2_dna.pth')
                artifact.add_file('gpt2_dna.pth')
                wandb.log_artifact(artifact)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                num_epochs_no_improve = 0
            else:
                num_epochs_no_improve += 1
                if num_epochs_no_improve >= config.training_config.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch + 1}...")
                    break

    except KeyboardInterrupt:
        print("Interrupted by user")
    finally:
        torch.save(model.state_dict(), 'gpt2_dna.pth')
        wandb.finish()


def load_datasets(dataset_config: DatasetConfig) -> Tuple[List[str], List[str]]:
    if dataset_config.dataset == DatasetType.HMM:
        train_seqs, val_seqs = generate_HMM_dataset(dataset_config.sequence_length,
                                                    N=dataset_config.num_seqs,
                                                    sparsity=dataset_config.sparsity,
                                                    num_hidden_states=dataset_config.num_hidden_states)
    elif dataset_config.dataset == DatasetType.VIRAL:
        set_random_seed(42)
        sequences = read_fna(file_path=dataset_config.file_path, shuffle=dataset_config.sequences_shuffle)
        sequences = sequences[:dataset_config.num_seqs]
        sub_seqs = extract_substrings(sequences,
                                      sequence_length=dataset_config.sequence_length,
                                      stride=dataset_config.stride,
                                      substrings_per_seq=dataset_config.substrings_per_seq)
        compute_char_probabilities(sub_seqs)
        train_size = int(len(sub_seqs) * (1 - dataset_config.split_ratio))
        train_seqs = sub_seqs[:train_size]
        val_seqs = sub_seqs[train_size:]
        train_seqs = train_seqs[:dataset_config.num_seqs]
        val_seqs = val_seqs[:dataset_config.num_seqs]
    return train_seqs, val_seqs


def main(config: Config) -> None:
    dataset_config = config.dataset_config
    model_config = config.model_config
    training_config = config.training_config

    train_seqs, val_seqs = load_datasets(dataset_config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = SequenceTokenizer()
    train_dataset = DNADataset(train_seqs, tokenizer)
    val_dataset = DNADataset(val_seqs, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=training_config.train_bs, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=training_config.val_bs, shuffle=False)

    wandb.init(project='GPT2_DNA', name=f'{dataset_config.dataset.name}', config=config)

    gpt2_config = GPT2Config(vocab_size=tokenizer.vocab_size,
                             n_positions=dataset_config.sequence_length,
                             n_ctx=dataset_config.sequence_length,
                             n_embd=model_config.n_embed,
                             n_layer=model_config.n_layer,
                             n_head=model_config.n_head)

    model = GPT2LMHeadModel(gpt2_config).to(device)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for training.")
        model = torch.nn.DataParallel(model)

    optimizer = AdamW(model.parameters(), lr=training_config.lr, weight_decay=training_config.weight_decay)

    train_loop(model, optimizer, train_loader, val_loader, device, config)


if __name__ == "__main__":
    dataset_config = HMMDatasetConfig(sequence_length=1000,
                                   split_ratio=0.5,
                                   sparsity=1.1,
                                   num_hidden_states=500,
                                   num_seqs=100,)
    # dataset_config = DatasetConfig(dataset=DatasetType.HMM,
    #                                file_path="./data/viral.1.1.genomic.fna",
    #                                sequence_length=1000,
    #                                stride=1000,
    #                                split_ratio=0.5,
    #                                substrings_per_seq=20,
    #                                sparsity=1.1,
    #                                num_hidden_states=1000,
    #                                num_seqs=100,
    #                                sequences_shuffle=True)

    model_config = ModelConfig(model_name='gpt2',
                               n_embed=512,
                               n_layer=4,
                               n_head=16)

    training_config = TrainingConfig(train_bs=32,
                                     val_bs=64,
                                     lr=1e-4,
                                     weight_decay=0.00,
                                     num_epochs=200,
                                     early_stopping_patience=5,
                                     print_every=20)

    config = Config(dataset_config=dataset_config,
                    model_config=model_config,
                    training_config=training_config)

    main(config)


Using 2 GPUs for training.




Epoch 1/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]



Epoch 1/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1/200 | Train Loss: 0.40098 | Val Loss: 0.15236 | Val Accuracy: 0.00350


Epoch 2/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 2/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 2/200 | Train Loss: 0.13750 | Val Loss: 0.06820 | Val Accuracy: 0.00375


Epoch 3/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 3/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 3/200 | Train Loss: 0.06927 | Val Loss: 0.05490 | Val Accuracy: 0.00367


Epoch 4/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 4/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 4/200 | Train Loss: 0.07406 | Val Loss: 0.04886 | Val Accuracy: 0.00405


Epoch 5/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 5/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 5/200 | Train Loss: 0.05553 | Val Loss: 0.04088 | Val Accuracy: 0.00398


Epoch 6/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 6/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 6/200 | Train Loss: 0.04447 | Val Loss: 0.03570 | Val Accuracy: 0.00409


Epoch 7/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 7/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 7/200 | Train Loss: 0.04177 | Val Loss: 0.03159 | Val Accuracy: 0.00398


Epoch 8/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 8/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 8/200 | Train Loss: 0.03640 | Val Loss: 0.02804 | Val Accuracy: 0.00348


Epoch 9/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 9/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 9/200 | Train Loss: 0.03179 | Val Loss: 0.02622 | Val Accuracy: 0.00396


Epoch 10/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 10/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 10/200 | Train Loss: 0.03405 | Val Loss: 0.02556 | Val Accuracy: 0.00381


Epoch 11/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 11/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 11/200 | Train Loss: 0.03084 | Val Loss: 0.02494 | Val Accuracy: 0.00418


Epoch 12/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 12/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 12/200 | Train Loss: 0.02691 | Val Loss: 0.02437 | Val Accuracy: 0.00391


Epoch 13/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 13/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 13/200 | Train Loss: 0.03600 | Val Loss: 0.02400 | Val Accuracy: 0.00351


Epoch 14/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 14/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 14/200 | Train Loss: 0.02700 | Val Loss: 0.02375 | Val Accuracy: 0.00347


Epoch 15/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 15/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 15/200 | Train Loss: 0.02853 | Val Loss: 0.02328 | Val Accuracy: 0.00346


Epoch 16/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 16/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 16/200 | Train Loss: 0.02765 | Val Loss: 0.02290 | Val Accuracy: 0.00347


Epoch 17/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 17/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 17/200 | Train Loss: 0.02796 | Val Loss: 0.02266 | Val Accuracy: 0.00314


Epoch 18/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 18/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 18/200 | Train Loss: 0.02881 | Val Loss: 0.02252 | Val Accuracy: 0.00314


Epoch 19/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 19/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 19/200 | Train Loss: 0.02783 | Val Loss: 0.02247 | Val Accuracy: 0.00333


Epoch 20/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 20/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 20/200 | Train Loss: 0.02640 | Val Loss: 0.02232 | Val Accuracy: 0.00334


Epoch 21/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 21/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 21/200 | Train Loss: 0.02644 | Val Loss: 0.02206 | Val Accuracy: 0.00330


Epoch 22/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 22/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 22/200 | Train Loss: 0.02624 | Val Loss: 0.02181 | Val Accuracy: 0.00307


Epoch 23/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 23/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 23/200 | Train Loss: 0.03097 | Val Loss: 0.02168 | Val Accuracy: 0.00316


Epoch 24/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 24/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 24/200 | Train Loss: 0.02559 | Val Loss: 0.02182 | Val Accuracy: 0.00382


Epoch 25/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 25/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 25/200 | Train Loss: 0.02828 | Val Loss: 0.02164 | Val Accuracy: 0.00386


Epoch 26/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 26/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 26/200 | Train Loss: 0.02773 | Val Loss: 0.02135 | Val Accuracy: 0.00387


Epoch 27/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 27/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 27/200 | Train Loss: 0.02389 | Val Loss: 0.02115 | Val Accuracy: 0.00386


Epoch 28/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 28/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 28/200 | Train Loss: 0.02573 | Val Loss: 0.02092 | Val Accuracy: 0.00384


Epoch 29/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 29/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 29/200 | Train Loss: 0.02444 | Val Loss: 0.02082 | Val Accuracy: 0.00432


Epoch 30/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 30/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 30/200 | Train Loss: 0.02438 | Val Loss: 0.02070 | Val Accuracy: 0.00436


Epoch 31/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 31/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 31/200 | Train Loss: 0.02375 | Val Loss: 0.02059 | Val Accuracy: 0.00429


Epoch 32/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 32/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 32/200 | Train Loss: 0.02580 | Val Loss: 0.02057 | Val Accuracy: 0.00433


Epoch 33/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 33/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 33/200 | Train Loss: 0.02763 | Val Loss: 0.02054 | Val Accuracy: 0.00433


Epoch 34/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 34/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 34/200 | Train Loss: 0.02411 | Val Loss: 0.02056 | Val Accuracy: 0.00431


Epoch 35/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 35/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 35/200 | Train Loss: 0.02982 | Val Loss: 0.02035 | Val Accuracy: 0.00453


Epoch 36/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 36/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 36/200 | Train Loss: 0.02921 | Val Loss: 0.02045 | Val Accuracy: 0.00436


Epoch 37/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 37/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 37/200 | Train Loss: 0.02281 | Val Loss: 0.02015 | Val Accuracy: 0.00431


Epoch 38/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 38/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 38/200 | Train Loss: 0.02301 | Val Loss: 0.01967 | Val Accuracy: 0.00446


Epoch 39/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 39/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 39/200 | Train Loss: 0.02682 | Val Loss: 0.01956 | Val Accuracy: 0.00455


Epoch 40/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 40/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 40/200 | Train Loss: 0.02138 | Val Loss: 0.02011 | Val Accuracy: 0.00430


Epoch 41/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 41/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 41/200 | Train Loss: 0.02212 | Val Loss: 0.01961 | Val Accuracy: 0.00444


Epoch 42/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 42/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 42/200 | Train Loss: 0.02797 | Val Loss: 0.01916 | Val Accuracy: 0.00436


Epoch 43/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 43/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 43/200 | Train Loss: 0.02144 | Val Loss: 0.01918 | Val Accuracy: 0.00433


Epoch 44/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 44/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 44/200 | Train Loss: 0.02174 | Val Loss: 0.01903 | Val Accuracy: 0.00427


Epoch 45/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 45/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 45/200 | Train Loss: 0.02367 | Val Loss: 0.01873 | Val Accuracy: 0.00438


Epoch 46/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 46/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 46/200 | Train Loss: 0.02191 | Val Loss: 0.01857 | Val Accuracy: 0.00436


Epoch 47/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 47/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 47/200 | Train Loss: 0.02316 | Val Loss: 0.01874 | Val Accuracy: 0.00431


Epoch 48/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 48/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 48/200 | Train Loss: 0.02196 | Val Loss: 0.01899 | Val Accuracy: 0.00445


Epoch 49/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 49/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 49/200 | Train Loss: 0.02100 | Val Loss: 0.01829 | Val Accuracy: 0.00435


Epoch 50/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 50/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 50/200 | Train Loss: 0.02187 | Val Loss: 0.01795 | Val Accuracy: 0.00430


Epoch 51/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 51/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 51/200 | Train Loss: 0.02537 | Val Loss: 0.01815 | Val Accuracy: 0.00375


Epoch 52/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 52/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 52/200 | Train Loss: 0.02430 | Val Loss: 0.01858 | Val Accuracy: 0.00379


Epoch 53/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 53/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 53/200 | Train Loss: 0.02087 | Val Loss: 0.01814 | Val Accuracy: 0.00432


Epoch 54/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 54/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 54/200 | Train Loss: 0.02321 | Val Loss: 0.01768 | Val Accuracy: 0.00449


Epoch 55/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 55/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 55/200 | Train Loss: 0.01998 | Val Loss: 0.01781 | Val Accuracy: 0.00411


Epoch 56/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 56/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 56/200 | Train Loss: 0.02211 | Val Loss: 0.01797 | Val Accuracy: 0.00406


Epoch 57/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 57/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 57/200 | Train Loss: 0.02524 | Val Loss: 0.01766 | Val Accuracy: 0.00406


Epoch 58/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 58/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 58/200 | Train Loss: 0.02810 | Val Loss: 0.01807 | Val Accuracy: 0.00407


Epoch 59/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 59/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 59/200 | Train Loss: 0.02305 | Val Loss: 0.01840 | Val Accuracy: 0.00419


Epoch 60/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 60/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 60/200 | Train Loss: 0.02047 | Val Loss: 0.01729 | Val Accuracy: 0.00467


Epoch 61/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 61/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 61/200 | Train Loss: 0.01976 | Val Loss: 0.01701 | Val Accuracy: 0.00417


Epoch 62/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 62/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 62/200 | Train Loss: 0.02014 | Val Loss: 0.01747 | Val Accuracy: 0.00364


Epoch 63/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 63/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 63/200 | Train Loss: 0.01919 | Val Loss: 0.01701 | Val Accuracy: 0.00373


Epoch 64/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 64/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 64/200 | Train Loss: 0.02111 | Val Loss: 0.01675 | Val Accuracy: 0.00383


Epoch 65/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 65/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 65/200 | Train Loss: 0.01924 | Val Loss: 0.01692 | Val Accuracy: 0.00315


Epoch 66/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 66/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 66/200 | Train Loss: 0.02221 | Val Loss: 0.01702 | Val Accuracy: 0.00376


Epoch 67/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 67/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 67/200 | Train Loss: 0.02255 | Val Loss: 0.01700 | Val Accuracy: 0.00372


Epoch 68/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 68/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 68/200 | Train Loss: 0.01974 | Val Loss: 0.01690 | Val Accuracy: 0.00366


Epoch 69/200 | Training:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch 69/200 | Validation:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 69/200 | Train Loss: 0.01871 | Val Loss: 0.01684 | Val Accuracy: 0.00422
Early stopping at epoch 69...


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
samples,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▃▄▅▅▅▅▆▃▃▃▁▂▂▁▄▄▄▇▆▆▇▆▇▇▇▆▇▇▇▄▆▅▅▅█▃▄▁▄▆
val_loss,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,69.0
samples,8832.0
train_loss,0.01871
val_accuracy,0.00422
val_loss,0.01684
