In [None]:
import os
from dataclasses import dataclass
from typing import Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchdiffeq import odeint


@dataclass
class TrainingConfig:
    """Hyperparameters and output location for one neural-ODE training run.

    Every field has a default, so ``TrainingConfig()`` yields a ready-to-use
    configuration; override individual values by keyword.
    """

    # --- optimisation loop ---
    batch_size: int = 32  # samples per DataLoader batch
    learning_rate: float = 1e-3  # initial learning rate handed to the optimizer
    max_epochs: int = 100  # hard upper bound on training epochs
    patience: int = 10  # non-improving epochs tolerated before early stopping
    grad_clip: float = 1.0  # max gradient norm used when clipping

    # --- data and model ---
    noise_std: float = 5e-2  # scale of the Gaussian noise added to the synthetic data
    hidden_size: int = 256  # width of the dynamics network's hidden layers
    train_split: float = 0.8  # fraction of time points used for training

    # --- output ---
    checkpoint_dir: str = "checkpoints"  # directory that receives best-model checkpoints
    
class ODEDynamics(nn.Module):
    """MLP vector field ``dy/dt = f(y)`` suitable as the ``func`` argument of ``odeint``.

    The network maps a state of size ``state_dim`` to a derivative of the same
    size through three hidden blocks (Linear -> LayerNorm -> ReLU), with dropout
    after the second block. The time argument is accepted for solver
    compatibility but is not used, i.e. the learned dynamics are autonomous.

    Args:
        hidden_size: Width of the first two hidden layers; the third hidden
            layer uses ``hidden_size // 2``.
        state_dim: Dimensionality of the ODE state (input and output size).
            Defaults to 3, the original hard-coded value.
        dropout: Dropout probability applied after the second hidden block.
            Defaults to 0.2, the original hard-coded value.
    """

    def __init__(self, hidden_size: int = 256, state_dim: int = 3, dropout: float = 0.2):
        super().__init__()
        half_width = hidden_size // 2
        # Layer order is kept exactly as before so that state_dict keys of
        # existing checkpoints (net.0, net.1, ...) remain loadable.
        # NOTE(review): in train mode dropout makes the vector field differ
        # between solver evaluations; this may interfere with adaptive
        # step-size solvers -- consider dropout=0.0 if training is unstable.
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, half_width),
            nn.LayerNorm(half_width),
            nn.ReLU(),
            nn.Linear(half_width, state_dim),
        )
        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-uniform weights and zero biases to every Linear layer."""
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, t, y):
        """Return the state derivative at ``y``.

        Args:
            t: Current time supplied by the solver; ignored.
            y: State tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor with the same shape as ``y``.
        """
        return self.net(y)

class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the validation loss. An epoch counts
    as an improvement when the loss is at least ``min_delta`` below the best
    loss seen so far; otherwise a counter is incremented. Once the counter
    reaches ``patience`` the instance returns ``True`` and keeps doing so.

    A NaN loss is always treated as a non-improving epoch and is never stored
    as the best loss. (Previously a NaN compared as "not worse", overwrote
    ``best_loss`` and reset the counter, after which stopping could never
    trigger.)

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, patience: int = 10, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
        """Record ``val_loss`` and return whether training should stop."""
        if val_loss != val_loss:
            # NaN is the only float value that is unequal to itself.
            improved = False
        elif self.best_loss is None:
            # First usable observation becomes the baseline.
            improved = True
        else:
            improved = val_loss <= self.best_loss - self.min_delta

        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

class ODETrainer:
    """Train ``ODEDynamics`` models on a synthetic 3-D trajectory.

    The trainer generates noisy reference data, splits it chronologically into
    train/validation sets, fits a model with a given ``odeint`` solver, writes
    scalars and trajectory figures to TensorBoard, and saves the best
    checkpoint per solver.

    Args:
        config: Hyperparameters and checkpoint directory.
        device: Device on which data and models are placed.

    Attributes:
        writer: ``SummaryWriter`` logging to ``runs/ODEExperiment``. The caller
            is responsible for closing it.
    """

    def __init__(self, config: TrainingConfig, device: torch.device):
        self.config = config
        self.device = device
        self.writer = SummaryWriter(log_dir="runs/ODEExperiment")
        os.makedirs(config.checkpoint_dir, exist_ok=True)

    def generate_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create the synthetic trajectory.

        Returns:
            ``(t, true_y)`` where ``t`` holds 1000 evenly spaced time points in
            ``[0, 10]`` and ``true_y`` has shape ``(1000, 3)``; each column is a
            trigonometric signal plus Gaussian noise scaled by
            ``config.noise_std``.
        """
        t = torch.linspace(0, 10, 1000).to(self.device)
        true_y = torch.zeros(1000, 3).to(self.device)
        noise_std = self.config.noise_std

        true_y[:, 0] = torch.sin(t) + 0.5 * torch.sin(2 * t) + noise_std * torch.randn_like(t)
        true_y[:, 1] = torch.cos(t) + 0.5 * torch.cos(3 * t) + noise_std * torch.randn_like(t)
        true_y[:, 2] = 0.5 * torch.sin(t) * torch.cos(t) + noise_std * torch.randn_like(t)

        return t, true_y

    def prepare_dataloaders(self, t: torch.Tensor, true_y: torch.Tensor):
        """Split the data chronologically and wrap both parts in DataLoaders.

        The first ``config.train_split`` fraction of time points is used for
        training (shuffled), the remainder for validation (in order).

        Returns:
            ``(train_loader, val_loader)``.
        """
        train_size = int(self.config.train_split * len(t))
        split_sizes = [train_size, len(t) - train_size]
        train_t, val_t = torch.split(t, split_sizes)
        train_y, val_y = torch.split(true_y, split_sizes)

        train_dataset = torch.utils.data.TensorDataset(train_t, train_y)
        val_dataset = torch.utils.data.TensorDataset(val_t, val_y)

        return (
            torch.utils.data.DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True),
            torch.utils.data.DataLoader(val_dataset, batch_size=self.config.batch_size)
        )

    def train_model(self, solver: str):
        """Train a fresh model using the given ``odeint`` solver.

        Args:
            solver: Name passed as ``method`` to ``odeint``.

        Returns:
            ``(model, best_val_loss)``; the model is in its final state (not
            necessarily the best checkpoint, which is saved to disk).
        """
        model = ODEDynamics(self.config.hidden_size).to(self.device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.config.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
        loss_fn = nn.MSELoss()
        early_stopping = EarlyStopping(patience=self.config.patience)

        t, true_y = self.generate_data()
        train_loader, val_loader = self.prepare_dataloaders(t, true_y)

        best_val_loss = float('inf')

        for epoch in range(self.config.max_epochs):
            train_loss = self._train_epoch(model, train_loader, optimizer, loss_fn, solver)
            val_loss = self._validate_epoch(model, val_loader, loss_fn, solver)

            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self._save_checkpoint(model, optimizer, epoch, val_loss, solver)

            self._log_metrics(train_loss, val_loss, model, true_y, t, solver, epoch)

            if early_stopping(val_loss):
                # ``epoch`` is 0-based, so epoch + 1 epochs have completed.
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

        return model, best_val_loss

    def _integrate_batch(self, model, t_batch, y_batch, solver):
        """Integrate the model over one batch of (time, state) samples.

        ``odeint`` requires strictly monotonic time points, and the initial
        state must belong to the earliest time. Training batches come from a
        shuffled DataLoader, so the batch is sorted by time first; the targets
        are reordered identically so they stay aligned with the prediction.
        Assumes the time points within a batch are unique (true for the
        ``linspace`` grid produced by ``generate_data``).

        Returns:
            ``(pred_y, y_sorted)``: the predicted states and the matching
            targets, both ordered by increasing time.
        """
        order = torch.argsort(t_batch)
        t_sorted = t_batch[order]
        y_sorted = y_batch[order]
        pred_y = odeint(model, y_sorted[0], t_sorted, method=solver)
        return pred_y, y_sorted

    def _train_epoch(self, model, train_loader, optimizer, loss_fn, solver) -> float:
        """Run one optimisation pass over ``train_loader``; return the mean batch loss."""
        model.train()
        total_loss = 0

        for t_batch, y_batch in train_loader:
            optimizer.zero_grad()
            pred_y, target_y = self._integrate_batch(model, t_batch, y_batch, solver)
            loss = loss_fn(pred_y, target_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.config.grad_clip)
            optimizer.step()
            total_loss += loss.item()

        return total_loss / len(train_loader)

    def _validate_epoch(self, model, val_loader, loss_fn, solver) -> float:
        """Evaluate on ``val_loader`` without gradients; return the mean batch loss."""
        model.eval()
        total_loss = 0

        with torch.no_grad():
            for t_val, y_val in val_loader:
                pred_val, target_val = self._integrate_batch(model, t_val, y_val, solver)
                total_loss += loss_fn(pred_val, target_val).item()

        return total_loss / len(val_loader)

    def _save_checkpoint(self, model, optimizer, epoch, val_loss, solver):
        """Write model/optimizer state to ``<checkpoint_dir>/model_<solver>_best.pt``."""
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'val_loss': val_loss
        }
        path = os.path.join(self.config.checkpoint_dir, f'model_{solver}_best.pt')
        torch.save(checkpoint, path)

    def _log_metrics(self, train_loss, val_loss, model, true_y, t, solver, epoch):
        """Log both losses for this epoch and, every fifth epoch, a trajectory figure."""
        self.writer.add_scalar(f"{solver}/Train Loss", train_loss, epoch)
        self.writer.add_scalar(f"{solver}/Validation Loss", val_loss, epoch)

        if epoch % 5 == 0:
            self._visualize_predictions(model, true_y, t, solver, epoch)

    def _visualize_predictions(self, model, true_y, t, solver, epoch):
        """Log a 3-D plot of the true vs. predicted trajectory over all of ``t``.

        Leaves the model in eval mode; ``_train_epoch`` switches it back to
        train mode at the start of the next epoch.
        """
        model.eval()
        with torch.no_grad():
            pred_y = odeint(model, true_y[0], t, method=solver)

        # Convert explicitly so matplotlib receives plain NumPy arrays.
        true_np = true_y.cpu().numpy()
        pred_np = pred_y.cpu().numpy()

        fig = plt.figure(figsize=(10, 8))
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.plot(true_np[:, 0], true_np[:, 1], true_np[:, 2],
                    label='True', linewidth=2, alpha=0.8)
            ax.plot(pred_np[:, 0], pred_np[:, 1], pred_np[:, 2],
                    label='Predicted', linewidth=2, alpha=0.8)
            ax.set_title(f"{solver} - Epoch {epoch}")
            ax.legend()
            self.writer.add_figure(f"{solver}/Trajectories", fig, epoch)
        finally:
            # Always release the figure, even if plotting or logging fails.
            plt.close(fig)

def main():
    """Train one model per ODE solver and report the best validation loss of each.

    Picks CUDA when available, builds a default ``TrainingConfig`` and a single
    ``ODETrainer``, then trains with the ``dopri5``, ``rk4`` and
    ``implicit_adams`` solvers in turn.

    Returns:
        Dict mapping solver name to ``{'model': ..., 'val_loss': ...}`` for
        every solver that finished training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = TrainingConfig()
    trainer = ODETrainer(config, device)

    solvers = ['dopri5', 'rk4', 'implicit_adams']
    results = {}

    try:
        for solver in solvers:
            print(f"\nTraining with {solver} solver...")
            model, val_loss = trainer.train_model(solver)
            results[solver] = {'model': model, 'val_loss': val_loss}
            print(f"Best validation loss with {solver}: {val_loss:.6f}")
    finally:
        # Flush and close the TensorBoard writer even if a solver run fails.
        trainer.writer.close()

    return results


if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'torch'

In [None]:
!conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y