# Model Training Notebook
This notebook demonstrates how to preprocess genome sequences and train the mutation detection model in GeneFix AI.

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from app.models.detection_model import MutationDetectionModel
from app.data_pipeline.sequence_cleaner import SequenceCleaner
import numpy as np

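ModuleNotFoundError: No module named 'app'

> **Note:** This error means Python cannot find the `app` package. Run the `sys.path.append(...)` setup cell at the end of this notebook (or start the notebook from a directory where `app` is importable) and then re-run the import cell above.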

## Example Data
We use synthetic DNA sequences for demonstration.

In [None]:
# Toy dataset: two synthetic 16-base DNA strings, one per class.
sequences = [
    'ATCGATCGATCGATCG',
    'TGCATGCATGCATGCA',
]
# Dummy class ids, aligned index-for-index with `sequences`
# (0: no mutation, 1: mutation).
labels = [
    0,
    1,
]

## Preprocessing
Convert DNA sequences to one-hot encoded tensors.

In [None]:
# One cleaner instance is shared by every preprocessing call in the notebook.
cleaner = SequenceCleaner()

# Encode each raw sequence, preserving the order of `sequences` so the
# results stay aligned with `labels`.
processed = list(map(cleaner.preprocess_sequence, sequences))

# Sanity check: report the shape of every encoded sequence (1-based numbering).
i = 0
while i < len(processed):
    encoded = processed[i]
    print(f"Sequence {i+1} shape:", encoded.shape)
    i += 1

## Model Initialization
We use a CNN+LSTM model for mutation detection.

In [None]:
# Network under training, built with its default constructor arguments.
model = MutationDetectionModel()

# Cross-entropy loss: targets are integer class ids, model outputs are
# treated as unnormalized class scores.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Training Loop
Train the model for two epochs on the example data, updating the weights after every individual sequence (batch size 1).

In [None]:
# Minimal training loop: two passes over the example data, one optimizer
# step per sequence (batch size 1).
#
# Switch to training mode explicitly. A freshly built nn.Module already is in
# training mode, but later cells in this notebook call model.eval(); without
# this line, re-running this cell afterwards would keep the model in
# inference mode (affecting e.g. dropout / batch-norm layers, if present).
model.train()

for epoch in range(2):
    for seq, label in zip(processed, labels):
        # unsqueeze(0) adds a leading batch dimension of size 1.
        tensor = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
        output = model(tensor)
        # The target is a 1-element tensor holding the integer class id,
        # matching the single-sample batch above.
        loss = criterion(output, torch.tensor([label]))
        # Clear gradients from the previous step before backpropagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")

## Pro Training Tips
Enhance your model training with these best practices:

- Use more data: The more diverse and realistic your training data, the better your model will generalize.
- Use validation sets: Always split your data into training and validation sets to monitor for overfitting.
- Track metrics: Log loss, accuracy, and other relevant metrics for each epoch.
- Use callbacks: Implement early stopping or learning rate schedulers for efficient training.
- Save checkpoints: Regularly save your model weights so you can resume or select the best model.
- Experiment: Try different architectures, hyperparameters, and data augmentation strategies.
- Use GPU: If available, move your model and tensors to CUDA for faster training.

In [None]:
from sklearn.model_selection import train_test_split

# Example: split data (expand for real use)
# Half of the samples go to validation; the fixed seed makes the split
# reproducible across runs.
split = train_test_split(sequences, labels, test_size=0.5, random_state=42)
train_seqs, val_seqs, train_labels, val_labels = split

# Preprocess
# Encode both partitions with the same cleaner, keeping their order so they
# stay aligned with `train_labels` / `val_labels`.
train_proc = list(map(cleaner.preprocess_sequence, train_seqs))
val_proc = list(map(cleaner.preprocess_sequence, val_seqs))


In [None]:
import copy

# Per-epoch loss history (plotted later) and best-checkpoint bookkeeping.
train_losses = []
val_losses = []
best_model = None
best_val_loss = float('inf')

for epoch in range(5):
    # ---- Training pass: one optimizer step per sample ----
    model.train()
    running_train = 0
    for sample, target in zip(train_proc, train_labels):
        # Add a leading batch dimension of size 1.
        inputs = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
        loss = criterion(model(inputs), torch.tensor([target]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_train += loss.item()
    # Mean training loss over the samples seen this epoch.
    train_losses.append(running_train / len(train_proc))

    # ---- Validation pass: no gradients, no weight updates ----
    model.eval()
    running_val = 0
    with torch.no_grad():
        for sample, target in zip(val_proc, val_labels):
            inputs = torch.tensor(sample, dtype=torch.float32).unsqueeze(0)
            running_val += criterion(model(inputs), torch.tensor([target])).item()
    val_loss = running_val / len(val_proc)
    val_losses.append(val_loss)
    print(f"Epoch {epoch}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_loss:.4f}")

    # Keep an independent copy of the weights whenever validation improves;
    # deepcopy is needed because state_dict() references the live parameters.
    if not (val_loss < best_val_loss):
        continue
    best_val_loss = val_loss
    best_model = copy.deepcopy(model.state_dict())


In [None]:
import matplotlib.pyplot as plt

# Draw both loss histories on the same axes, one point per recorded epoch.
for curve, curve_name in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    plt.plot(curve, label=curve_name)

# Axis labels, title and legend, then render the figure.
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

In [None]:
# Persist the best weights found so far. `best_model` holds a state_dict
# (not a full model object); it is still None if no training epoch has
# recorded an improvement, in which case nothing is written.
has_checkpoint = best_model is not None
if has_checkpoint:
    torch.save(obj=best_model, f='best_mutation_detection_model.pt')
    print('Best model saved as best_mutation_detection_model.pt')

## Advanced Training: Learning Rate Scheduling and Early Stopping
Enhance your training with learning rate schedulers and early stopping to avoid overfitting and speed up convergence.

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# NOTE: this cell continues the previous training cell. It reuses `copy`,
# `best_val_loss` and `best_model` from there and keeps appending to the
# existing `train_losses` / `val_losses` histories.

# Halve the learning rate once validation loss has failed to improve for more
# than `patience` consecutive epochs.
# `verbose=True` is deliberately NOT passed: the argument is deprecated since
# PyTorch 2.2 and rejected by later releases. Learning-rate changes are
# reported explicitly after scheduler.step() below instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
# Stop training after this many consecutive epochs without a new best
# validation loss.
early_stop_patience = 3
no_improve_epochs = 0

for epoch in range(20):
    # ---- Training pass: one optimizer step per sample ----
    model.train()
    epoch_train_loss = 0
    for seq, label in zip(train_proc, train_labels):
        # unsqueeze(0) adds a leading batch dimension of size 1.
        tensor = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
        output = model(tensor)
        loss = criterion(output, torch.tensor([label]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()
    train_losses.append(epoch_train_loss / len(train_proc))

    # ---- Validation pass: no gradients, no weight updates ----
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for seq, label in zip(val_proc, val_labels):
            tensor = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)
            output = model(tensor)
            loss = criterion(output, torch.tensor([label]))
            epoch_val_loss += loss.item()
    val_loss = epoch_val_loss / len(val_proc)
    val_losses.append(val_loss)
    print(f"Epoch {epoch}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_loss:.4f}")

    # Let the scheduler react to the validation loss, and report any
    # learning-rate change ourselves (replacement for the old verbose flag).
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"Epoch {epoch}: learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

    # Track the best weights and count epochs without improvement.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # deepcopy: state_dict() references the live parameters.
        best_model = copy.deepcopy(model.state_dict())
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
    if no_improve_epochs >= early_stop_patience:
        print(f"Early stopping at epoch {epoch}!")
        break


In [None]:
import sys
import os

# Add the directories two levels and one level above the current working
# directory to the import search path. The import cell at the top of the
# notebook needs the `app` package; presumably it lives under one of these
# directories, so run this cell before that one.
# Each path is appended only if it is not already present, so re-running the
# cell does not pile up duplicate sys.path entries.
for _relative_dir in ("../..", ".."):
    _absolute_dir = os.path.abspath(_relative_dir)
    if _absolute_dir not in sys.path:
        sys.path.append(_absolute_dir)

In [None]:
import sys
import os

# This cell repeats the import-path setup of the cell directly above it.
# It makes the directories two levels and one level above the current working
# directory importable, skipping any path that is already on sys.path so that
# running both cells (or re-running either) never adds duplicate entries.
for _relative_dir in ("../..", ".."):
    _absolute_dir = os.path.abspath(_relative_dir)
    if _absolute_dir not in sys.path:
        sys.path.append(_absolute_dir)