# Definitions

In [1]:
import os
import math
import re
import random
from enum import Enum
from typing import List, Tuple, DefaultDict, Set
from collections import defaultdict, Counter

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Config, GPT2LMHeadModel, AdamW
from tqdm.auto import tqdm
from Bio import SeqIO
import wandb


os.environ["WANDB_NOTEBOOK_NAME"] = 'GPT_dBG.ipynb'


class DatasetType(Enum):
    HMM = "HMM"
    VIRAL = "viral"

class Config:
    def __init__(self):
        self.dataset: DatasetType = DatasetType.HMM
        self.model_name: str = 'gpt2'
        self.sequence_length: int = 1000
        self.num_seqs: int = 500
        self.num_hidden_states: int = 500
        self.sparsity: float = 1.1
        self.protein: bool = True
        # self.stride: int = 1000
        # self.split_ratio: float = 0.5
        # self.file_path: str = "./data/viral.1.1.genomic.fna"
        # self.substrings_per_seq: int = 20
        # self.sequences_shuffle: bool = True
        self.train_bs: int = 64
        self.val_bs: int = 128
        self.n_embed: int = 512
        self.n_layer: int = 4
        self.n_head: int = 1
        self.lr: float = 1e-4
        self.weight_decay: float = 0.00
        self.num_epochs: int = 200
        self.early_stopping_patience: int = 3
        self.print_every: int = 20
        self.save_model_every: int = 20

class Config:
    def __init__(self):
        self.dataset: DatasetType = DatasetType.VIRAL
        self.model_name: str = 'gpt2'
        self.file_path: str = "./data/viral.1.1.genomic.fna"
        self.sequence_length: int = 1000
        self.stride: int = 1000
        self.num_seqs: int = 100000
        self.split_ratio: float = 0.2
        self.substrings_per_seq: int = 20
        self.sequences_shuffle: bool = True
        self.protein: bool = True
        self.sparsity: float = 1.1
        self.num_hidden_states: int = 100
        self.train_bs: int = 16
        self.val_bs: int = 32
        self.n_embed: int = 512
        self.n_layer: int = 6
        self.n_head: int = 1
        self.lr: float = 1e-4
        self.weight_decay: float = 0.00
        self.num_epochs: int = 200
        self.early_stopping_patience: int = 5
        self.print_every: int = 20
        self.save_model_every: int = 20
        
# DNA codon table
dna_codon_table = {
    'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
    'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
    'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
    'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
    'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
    'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
    'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
    'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
    'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
}


def dna2protein(dna_sequence):
    protein_dna = ''.join(dna_codon_table.get(dna_sequence[i:i+3], 'X') for i in range(0, len(dna_sequence), 3))
    return protein_dna

class SequenceTokenizer:
    def __init__(self, protein: bool = False):
        if protein:
            self.alphabet = list(dna_codon_table.values())+['X']
            self.alphabet = set(self.alphabet)
        else:
            self.alphabet = {'A', 'C', 'G', 'T', 'X'}
        self.token_to_idx = {char: i for i, char in enumerate(self.alphabet)}
        self.idx_to_token = {i: char for i, char in enumerate(self.alphabet)}
        self.vocab_size = len(self.token_to_idx)

    def encode(self, sequence: str, return_tensors: str = "pt") -> torch.Tensor:
        tokens = [self.token_to_idx[char] for char in sequence]
        tokens = torch.tensor(tokens, dtype=torch.long)
        return tokens

    def decode(self, tokens: torch.Tensor) -> str:
        sequence = [self.idx_to_token[int(token.item())] for token in tokens]
        return ''.join(sequence)


class DNADataset(Dataset):
    def __init__(self, sequences: List[str], tokenizer: SequenceTokenizer, protein: bool = False):
        self.sequences = sequences
        self.tokenizer = tokenizer
        if protein:
            self.diff = 3
        else:
            self.diff = 1

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        sequence = self.sequences[i]
        inputs = self.tokenizer.encode(sequence, return_tensors='pt')
        targets = inputs[self.diff:].clone()
        inputs = inputs[:-self.diff]
        return inputs, targets


def compute_char_probabilities(sequences: List[str]) -> None:
    counts = defaultdict(int)
    total_count = 0
    for seq in sequences:
        cd = Counter(seq)
        for c, count in cd.items():
            counts[c] += count
            total_count += count
    pcounts = {char: c / total_count for char, c in counts.items()}
    print(pcounts)

def calculate_entropy(sequences, k):
    substrings = [seq[i:i+k] for seq in sequences for i in range(len(seq) - k + 1)]
    frequencies = Counter(substrings)
    entropy = 0.0
    total = sum(frequencies.values())
    probabilities = {k: v / total for k, v in frequencies.items()}
    for p in probabilities.values():
        entropy -= p * math.log2(p)
    print(probabilities)
    print('entropy = ', entropy/k)
 

def read_fna(file_path: str, shuffle: bool = False) -> List[str]:
    sequences = []
    with open(file_path, "r") as f:
        for record in SeqIO.parse(f, "fasta"):
            sequences.append(str(record.seq))
    if shuffle:
        random.shuffle(sequences)
    return sequences


def generate_markov_chain(n: int, sparsity: float) -> torch.Tensor:
    transition_probs = torch.rand(n, n)
    for i in range(n):
        num_outgoing_states = torch.poisson(torch.tensor([float(sparsity)])).int().item()
        num_outgoing_states = min(n, max(1, num_outgoing_states))
        _, indices = torch.topk(transition_probs[i], int(num_outgoing_states))
        mask = torch.zeros_like(transition_probs[i]).scatter_(0, indices, 1).to(torch.bool)
        transition_probs[i] *= mask
    transition_probs = F.normalize(transition_probs, p=1, dim=1)
    return transition_probs


def draw_seq(transition_probs: torch.Tensor, sequence_length: int) -> str:
    NUCLEOTIDES = ['A', 'C', 'G', 'T']
    current_state = 0
    chain = [current_state]
    while len(chain) < sequence_length:
        next_state = torch.multinomial(transition_probs[int(current_state)], num_samples=1)
        current_state = next_state.item()
        chain.append(int(current_state))
    seq = [NUCLEOTIDES[s % len(NUCLEOTIDES)] for s in chain]
    return ''.join(seq)


def generate_HMM_dataset(sequence_length: int, N: int, sparsity: float, num_hidden_states: int) -> List[List[str]]:
    if num_hidden_states is None:
        num_hidden_states = sequence_length
    transition_matrix = generate_markov_chain(num_hidden_states, sparsity=sparsity)
    all_seqs = [[], []]
    for i in range(2):
        seqs = []
        while len(seqs) < N:
            seq = draw_seq(transition_matrix, sequence_length)
            if len(seq) >= sequence_length:
                seqs.append(seq)
        all_seqs[i] = seqs
    return all_seqs


def generate_phylo_dataset(sequence_length: int, mutation_rate: float, N: int) -> List[str]:
    NUCLEOTIDES = ['A', 'C', 'G', 'T']
    parent_sequence = ''.join(random.choice(NUCLEOTIDES) for _ in range(sequence_length))
    dataset = [parent_sequence]
    for _ in range(N):
        mutated_sequence = ''
        for nucleotide in parent_sequence:
            if random.random() < mutation_rate:
                mutated_sequence += random.choice(NUCLEOTIDES)
            else:
                mutated_sequence += nucleotide
        dataset.append(mutated_sequence)
        parent_sequence = mutated_sequence
    return dataset


def construct_debruijn_graph(dataset: List[str], k: int) -> DefaultDict[str,Set[str]]:
    graph = defaultdict(set)
    for sequence in dataset:
        for i in range(len(sequence) - k):
            graph[sequence[i:i + k]].add(sequence[i + 1:i + k + 1])
    return graph


def extract_substrings(sequences: List[str], sequence_length: int, stride: int, substrings_per_seq: int, protein: bool = False) -> List[str]:
    substrings = []
    for sequence in sequences:
        if not bool(re.match("^[ACGT]+$", sequence)):
            continue
        if protein:
            sequence = dna2protein(sequence)
        for i in range(0, len(sequence) - sequence_length + 1, stride):
            if i // stride > substrings_per_seq:
                break
            seq = sequence[i:i + sequence_length]
            substrings.append(seq)
    return substrings


def set_random_seed(seed: int) -> None:
    torch.manual_seed(seed)
    random.seed(seed)


def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, device: torch.device, description: str, print_every: int) -> float:
    model.train()
    running_loss = []
    bar = tqdm(train_loader, desc=description)
    criterion = torch.nn.CrossEntropyLoss()
    for i, (inputs, targets) in enumerate(bar):
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = model(inputs, labels=targets)
        # loss = outputs.loss.mean()
        loss = criterion(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss.append(loss.item())
        if (i+1) % print_every == 0:
            avg_loss = sum(running_loss) / len(running_loss)
            bar.set_postfix({"Train Loss": f"{avg_loss:.5f}"})
            running_loss = []
    return sum(running_loss) / len(running_loss)


def evaluate(model: torch.nn.Module, val_loader: DataLoader, device: torch.device, description: str, print_every: int) -> Tuple[float, float]:
    model.eval()
    total_loss = []
    running_loss = []
    total_acc = []
    bar = tqdm(val_loader, desc=description)
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for i,(inputs, targets) in enumerate(bar):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs, labels=targets)
            val_loss = criterion(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1))
            predictions = torch.argmax(outputs.logits, dim=-1)
            accuracy = (predictions.int() == targets.int()).float().mean().item()
            total_loss.append(val_loss.item())
            running_loss.append(val_loss.item())
            total_acc.append(accuracy)
            if (i + 1) % print_every  == 0:
                avg_loss = sum(running_loss) / len(running_loss)
                bar.set_postfix({"Val Loss": f"{avg_loss:.5f}", "Val Accuracy": accuracy})
                running_loss = []
                # bar.set_postfix({"Val Loss": val_loss.item(), "Val Accuracy": accuracy})
    avg_val_loss = sum(total_loss) / len(total_loss)
    avg_accuracy = sum(total_acc) / len(total_acc)
    return avg_val_loss, avg_accuracy


def train_loop(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, config: Config) -> None:
    best_val_loss = float('inf')
    num_epochs_no_improve = 0  # Number of epochs with no improvement in validation loss

    try:
        for epoch in range(config.num_epochs):
            train_loss = train(model, optimizer, train_loader, device, f"Epoch {epoch + 1}/{config.num_epochs} | Training", config.print_every)
            val_loss, val_acc = evaluate(model, val_loader, device, f"Epoch {epoch + 1}/{config.num_epochs} | Validation", config.print_every)
            samples = (epoch+1) * len(train_loader) * config.train_bs
            print(f"Epoch {epoch + 1}/{config.num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f} | Val Accuracy: {val_acc:.5f}")
            wandb.log({"epoch": epoch + 1, "samples": samples, "train_loss": train_loss, "val_loss": val_loss, "val_accuracy": val_acc})
            
            if (epoch + 1) % config.save_model_every == 0:
                # Save the model weights as an artifact every 10 epochs
                artifact = wandb.Artifact(f"model_weights", type='model')
                torch.save(model, "gpt2_dna.pt")
                # torch.save(model.state_dict(), 'gpt2_dna.pt')
                artifact.add_file('gpt2_dna.pt')
                wandb.log_artifact(artifact)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                num_epochs_no_improve = 0
            else:
                num_epochs_no_improve += 1
                if num_epochs_no_improve >= config.early_stopping_patience:
                    print(f"Early stopping at epoch {epoch + 1}...")
                    break

    except KeyboardInterrupt:
        print("Interrupted by user")
    finally:
        # torch.save(model.state_dict(), 'gpt2_dna.pt')
        torch.save(model, "gpt2_dna.pt")
        wandb.finish()



def load_datasets(config: Config) -> Tuple[List[str], List[str]]:
    if config.dataset == DatasetType.HMM:
        train_seqs, val_seqs = generate_HMM_dataset(config.sequence_length, N=config.num_seqs, sparsity=config.sparsity,
                                                    num_hidden_states=config.num_hidden_states)
    elif config.dataset == DatasetType.VIRAL:
        set_random_seed(42)
        sequences = read_fna(file_path=config.file_path, shuffle=config.sequences_shuffle)
        sequences = sequences[:config.num_seqs]
        sub_seqs = extract_substrings(sequences, sequence_length=config.sequence_length, stride=config.stride,
                                      substrings_per_seq=config.substrings_per_seq, protein=config.protein)
        # compute_char_probabilities(sub_seqs)
        train_size = int(len(sub_seqs) * (1 - config.split_ratio))
        train_seqs = sub_seqs[:train_size]
        val_seqs = sub_seqs[train_size:]
        train_seqs = train_seqs[:config.num_seqs]
        val_seqs = val_seqs[:config.num_seqs]
        calculate_entropy(train_seqs, k=2)
        calculate_entropy(train_seqs, k=2)
    return train_seqs, val_seqs



def main(config: Config) -> None:
    train_seqs, val_seqs = load_datasets(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'

    tokenizer = SequenceTokenizer(protein=config.protein, )
    train_dataset = DNADataset(train_seqs, tokenizer, protein=config.protein)
    val_dataset = DNADataset(val_seqs, tokenizer, protein=config.protein)

    train_loader = DataLoader(train_dataset, batch_size=config.train_bs, drop_last=True,shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.val_bs, drop_last=True,shuffle=False)
    name = f"{config.dataset.name}_{'protein' if config.protein else 'DNA'}__lr={config.lr}_bs={config.train_bs}_n_embed={config.n_embed}_n_head={config.n_head}_n_layer={config.n_layer}_early_stopping={config.early_stopping_patience}_num_seqs={config.num_seqs}_sequence_length={config.sequence_length}_stride={config.stride}"
    wandb.init(project='GPT2_DNA', name=name, config=config)

    gpt2_config = GPT2Config(vocab_size=tokenizer.vocab_size,
                             n_positions=config.sequence_length,
                             n_ctx=config.sequence_length,
                             n_embd=config.n_embed,
                             n_layer=config.n_layer,
                             n_head=config.n_head)

    model = GPT2LMHeadModel(gpt2_config).to(device)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for training.")
        model = torch.nn.DataParallel(model)

    optimizer = AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    train_loop(model, optimizer, train_loader, val_loader, device, config)



if __name__ == "__main__":
    config = Config()
    main(config)

{'RA': 0.006065065065065065, 'AI': 0.0027672672672672672, 'IG': 0.002278978978978979, 'GA': 0.004077277277277277, 'AP': 0.0036585585585585585, 'PI': 0.0020240240240240242, 'IT': 0.0029721721721721723, 'TR': 0.005290590590590591, 'R*': 0.0043785785785785784, '*I': 0.002543743743743744, 'II': 0.0032987987987987987, 'IK': 0.002590790790790791, 'K*': 0.0018447447447447448, '*L': 0.003723323323323323, 'LT': 0.004916216216216216, 'RN': 0.0029266266266266268, 'NN': 0.0015734734734734735, 'NV': 0.001983983983983984, 'VS': 0.004471671671671671, 'SL': 0.007213113113113113, 'LF': 0.0036776776776776777, 'F*': 0.001904904904904905, '*R': 0.004263963963963964, 'RD': 0.0026816816816816816, 'DR': 0.003032932932932933, 'RI': 0.003704004004004004, 'IF': 0.0025684684684684684, 'FP': 0.0015882882882882882, 'PF': 0.0016730730730730732, 'FF': 0.0020525525525525523, 'FK': 0.00166986986986987, 'KE': 0.0017528528528528528, 'EL': 0.0029997997997997998, 'LK': 0.004083783783783784, 'KQ': 0.0015834834834834834, 'Q

[34m[1mwandb[0m: Currently logged in as: [33mamirjoudaki[0m ([33msketch-bros[0m). Use [1m`wandb login --relogin`[0m to force relogin


Using 2 GPUs for training.




Epoch 1/200 | Training:   0%|          | 0/625 [00:00<?, ?it/s]



Epoch 1/200 | Validation:   0%|          | 0/312 [00:00<?, ?it/s]

Epoch 1/200 | Train Loss: 2.85786 | Val Loss: 2.84106 | Val Accuracy: 0.12670


Epoch 2/200 | Training:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 2/200 | Validation:   0%|          | 0/312 [00:00<?, ?it/s]

Epoch 2/200 | Train Loss: 2.80991 | Val Loss: 2.82742 | Val Accuracy: 0.13002


Epoch 3/200 | Training:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 3/200 | Validation:   0%|          | 0/312 [00:00<?, ?it/s]

Epoch 3/200 | Train Loss: 2.84227 | Val Loss: 2.82490 | Val Accuracy: 0.13098


Epoch 4/200 | Training:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 4/200 | Validation:   0%|          | 0/312 [00:00<?, ?it/s]

Epoch 4/200 | Train Loss: 2.81720 | Val Loss: 2.82303 | Val Accuracy: 0.13084


Epoch 5/200 | Training:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 5/200 | Validation:   0%|          | 0/312 [00:00<?, ?it/s]

In [1]:
def get_model(config: Config) -> None:
    train_seqs, val_seqs = load_datasets(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = SequenceTokenizer()
    train_dataset = DNADataset(train_seqs, tokenizer)
    val_dataset = DNADataset(val_seqs, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=config.train_bs, drop_last=True,shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.val_bs, drop_last=True,shuffle=False)

    model = torch.load('gpt2_dna.pt')

    return model, train_loader, val_loader, device

model, train_loader, val_loader, device = get_model(Config())

inputs, labels = next(iter(val_loader))
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs, labels=labels)

predicitions = torch.argmax(outputs.logits, dim=-1)
tokenizer = SequenceTokenizer()
i = 2
acc = (predicitions[i] == labels[i]).float().mean().item()
print(acc)
x = tokenizer.decode(inputs[i])
y = tokenizer.decode(predicitions[i])


x, y

NameError: name 'Config' is not defined

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

class Config:
    def __init__(self):
        self.dataset: DatasetType = DatasetType.VIRAL
        self.model_name: str = 'gpt2'
        self.file_path: str = "./data/viral.1.1.genomic.fna"
        self.sequence_length: int = 1000
        self.stride: int = 1000
        self.num_seqs: int = 1000
        self.split_ratio: float = 0.5
        self.substrings_per_seq: int = 20
        self.sequences_shuffle: bool = True
        self.protein: bool = False
        # self.sparsity: float = 1.1
        # self.num_hidden_states: int = 100
        self.train_bs: int = 32
        self.val_bs: int = 256
        self.n_embed: int = 1024
        self.n_layer: int = 3
        self.n_head: int = 1
        self.lr: float = 0.005
        self.weight_decay: float = 0.00
        self.num_epochs: int = 200
        self.early_stopping_patience: int = 5
        self.print_every: int = 20
        self.save_model_every: int = 20

# Define the LSTM model
class DNAPredictor(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(DNAPredictor, self).__init__()
        self.hidden_size = hidden_size
        self.one_hot = lambda x: torch.functional.F.one_hot(x, num_classes=output_size).float()
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, output_size)
    
    def forward(self, input_seq):
        input_seq = self.one_hot(input_seq)
        lstm_out, _ = self.lstm(input_seq)
        output = self.fc(lstm_out)
        return output
    

config = Config()
train_seqs, val_seqs = load_datasets(config)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

tokenizer = SequenceTokenizer()
train_dataset = DNADataset(train_seqs, tokenizer)
val_dataset = DNADataset(val_seqs, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=config.train_bs, drop_last=True,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.val_bs, drop_last=True,shuffle=False)

output_size = tokenizer.vocab_size
# Initialize the model
model = DNAPredictor(tokenizer.vocab_size, config.n_embed, tokenizer.vocab_size, num_layers=config.n_layer).to(device)
criteria = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.lr)

# Train the model
for epoch in range(config.num_epochs):
    total_loss = []
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Forward pass
        outputs = model(inputs)
        loss = criteria(outputs.view(-1,outputs.shape[-1]), labels.flatten())
        total_loss.append(loss.item())
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_loss = sum(total_loss)/len(total_loss)
    # print(f'Epoch [{epoch+1}/{config.num_epochs}], Train Loss: {train_loss:.4f}')

    # evaluate
    with torch.no_grad():
        total_loss = []
        for i, (inputs, labels) in enumerate(val_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criteria(outputs.view(-1,outputs.shape[-1]), labels.flatten())
            total_loss.append(loss.item())
        val_loss = sum(total_loss)/len(total_loss)
        print(f'Epoch [{epoch+1}/{config.num_epochs}],  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')


{'AG': 0.060456456456456455, 'GA': 0.06808208208208208, 'GC': 0.06076976976976977, 'CA': 0.06842842842842843, 'AA': 0.08358058058058059, 'AT': 0.0681901901901902, 'TC': 0.0603003003003003, 'CG': 0.057458458458458456, 'GG': 0.05754654654654655, 'GT': 0.054285285285285284, 'TG': 0.06524224224224225, 'CC': 0.05598198198198198, 'TT': 0.07095595595595595, 'TA': 0.0526036036036036, 'AC': 0.06046046046046046, 'CT': 0.05565765765765766}
entropy =  1.9947617887039697
{'AG': 0.060456456456456455, 'GA': 0.06808208208208208, 'GC': 0.06076976976976977, 'CA': 0.06842842842842843, 'AA': 0.08358058058058059, 'AT': 0.0681901901901902, 'TC': 0.0603003003003003, 'CG': 0.057458458458458456, 'GG': 0.05754654654654655, 'GT': 0.054285285285285284, 'TG': 0.06524224224224225, 'CC': 0.05598198198198198, 'TT': 0.07095595595595595, 'TA': 0.0526036036036036, 'AC': 0.06046046046046046, 'CT': 0.05565765765765766}
entropy =  1.9947617887039697
Epoch [1/200],  Train Loss: 1.6878, Val Loss: 1.3884
Epoch [2/200],  Train