<a href="https://colab.research.google.com/github/DeependraChaddha/Vlasov_Poisson_Solver/blob/main/vlasov_poisson_system_solver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Script for the Vlasov-Poisson System Solver

In [None]:
%%writefile vlasov_poisson_solver.py
import os
import torch
import torch.nn as nn
import torch.optim
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from typing import Callable
import tqdm
import pickle

class VlasovPoissonSolver:
    """Train and visualise a neural-network model on a regular (x, v, t) grid.

    The solver builds the grid, runs a user-supplied loss function over it to
    train a model, checkpoints the best model seen, and can animate the model's
    prediction along the time axis.

    Call conventions used below:
        * ``model(x, v, t, nx, nv)`` produces the predicted distribution.
        * ``loss_fn(model=..., x=..., v=..., t=..., nx=..., nv=...)`` returns a
          scalar tensor supporting ``.backward()`` and ``.item()``.
    """

    def __init__(self, nx: int, nv: int, nt: int, x_range: tuple, v_range: tuple, t_range: tuple, device: str):
        """Store grid sizes/ranges and select the compute device.

        Args:
            nx, nv, nt: Number of grid points along x, v and t.
            x_range, v_range, t_range: ``(start, end)`` pairs for each axis;
                both ends are included in the grid (``torch.linspace``).
            device: Requested torch device. Replaced by ``'cpu'`` whenever
                CUDA is not available.
        """
        self.nx = nx
        self.nv = nv
        self.nt = nt
        self.x_range = x_range
        self.v_range = v_range
        self.t_range = t_range
        self.device = device if torch.cuda.is_available() else 'cpu'
        self.model_checkpoint_path = None
        self.checkpoint_dir = None
        # Grid tensors; populated by make_grid().
        self.X = None
        self.V = None
        self.T = None
        print("Instance of Class VlasovPoissonSolver created.")

    def _has_grid(self):
        """Return True once make_grid() has populated the grid tensors."""
        return getattr(self, "X", None) is not None

    def _checkpoint_dir_for(self, model_name, epochs):
        """Return the checkpoint directory name for this grid, model name and epoch count."""
        return f"checkpoint_{model_name}_nx{self.nx}_nv{self.nv}_nt{self.nt}_epochs{epochs}"

    def make_grid(self):
        """Create the (x, v, t) mesh on ``self.device`` with gradients enabled.

        Returns:
            ``(X, V, T)``: tensors of shape ``(nx, nv, nt)`` built with ``'ij'``
            indexing (axis 0 = x, axis 1 = v, axis 2 = t). They are also stored
            on the instance as ``self.X``, ``self.V``, ``self.T``.
        """
        axes = [
            torch.linspace(start, end, count, device=self.device)
            for (start, end), count in (
                (self.x_range, self.nx),
                (self.v_range, self.nv),
                (self.t_range, self.nt),
            )
        ]
        X, V, T = torch.meshgrid(*axes, indexing='ij')
        # Build on the target device *before* enabling gradients so the stored
        # grids are leaf tensors (a .to() after requires_grad_ yields non-leaf copies).
        self.X, self.V, self.T = (grid.requires_grad_(True) for grid in (X, V, T))
        return self.X, self.V, self.T

    def save_checkpoint(self, model, optimizer, loss, model_name, hyperparameters):
        """Save model/optimizer state, loss and hyperparameters via ``torch.save``.

        The file is written to ``<checkpoint_dir>/model_checkpoint.pkl``. The
        directory name is derived from ``model_name``, the grid sizes and
        ``hyperparameters['epochs']``.

        Side effects:
            Sets ``self.checkpoint_dir`` and ``self.model_checkpoint_path``.
        """
        self.checkpoint_dir = self._checkpoint_dir_for(model_name, hyperparameters['epochs'])
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.checkpoint_dir, "model_checkpoint.pkl")
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            "hyperparameters": hyperparameters
        }
        torch.save(checkpoint, checkpoint_path)
        self.model_checkpoint_path = checkpoint_path

    def train_step(self, model, X, V, T, loss_fn, optimizer, scheduler=None):
        """Run one forward pass, loss evaluation and optimizer update.

        ``ReduceLROnPlateau`` schedulers are stepped with the loss value; any
        other scheduler is stepped without arguments.

        Returns:
            ``(X, V, T, prediction, loss_value, state_dict)``. ``prediction`` is
            the model output computed before the update, and ``loss_value`` is a
            Python float.
        """
        prediction = model(X, V, T, self.nx, self.nv)
        loss = loss_fn(model=model, x=X, v=V, t=T, nx=self.nx, nv=self.nv)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(loss_value)
            else:
                scheduler.step()
        return X, V, T, prediction, loss_value, model.state_dict()

    def _plot_loss_curve(self, loss_values, checkpoint_dir):
        """Plot per-epoch losses, save them as loss_curve.png in checkpoint_dir and show the figure."""
        # The directory may not exist yet if no epoch ever improved on +inf (e.g. NaN losses).
        os.makedirs(checkpoint_dir, exist_ok=True)
        plt.figure(figsize=(8, 5))
        plt.plot(range(1, len(loss_values) + 1), loss_values, marker='o', linestyle='-', color='r', label="Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Loss Curve Over Epochs")
        plt.legend()
        plt.grid()
        loss_curve_path = os.path.join(checkpoint_dir, "loss_curve.png")
        plt.savefig(loss_curve_path)
        print(f"Loss curve saved at {loss_curve_path}")
        plt.show()

    def train(self, model, loss_fn, optimizer, epochs, model_name, get_loss_curve=False, scheduler=None):
        """Train ``model`` for ``epochs`` steps, checkpointing whenever the loss improves.

        Args:
            model: Network called as ``model(x, v, t, nx, nv)``; moved to ``self.device``.
            loss_fn: Loss callable (see class docstring).
            optimizer: Optimizer over ``model``'s parameters.
            epochs: Number of training steps.
            model_name: Used to name the checkpoint directory.
            get_loss_curve: If True, plot, save and show the loss curve.
            scheduler: Optional learning-rate scheduler, stepped once per epoch.

        Returns:
            ``self.model_checkpoint_path``. Returns ``None`` after printing a
            message if ``make_grid`` has not been called yet.
        """
        if not self._has_grid():
            print("Make grid first, Call make_grid method")
            return None

        X, V, T = self.X, self.V, self.T
        model = model.to(self.device)
        best_loss = float('inf')
        loss_values = []
        for _ in tqdm.tqdm(range(epochs)):
            # Keep only the scalar loss so the returned prediction (and its
            # autograd graph) can be freed before the next step.
            loss = self.train_step(model, X, V, T, loss_fn, optimizer, scheduler)[4]
            loss_values.append(loss)
            if loss < best_loss:
                best_loss = loss
                self.save_checkpoint(model, optimizer, loss, model_name, {"epochs": epochs})
        print(f"Best Loss Achieved: {best_loss}")
        if get_loss_curve:
            self._plot_loss_curve(loss_values, self._checkpoint_dir_for(model_name, epochs))
        return self.model_checkpoint_path

    def animate_final_prediction(self, model_class, model_checkpoint_path=None):
        """Animate the model's predicted distribution over the time axis.

        If ``train`` was called before, its best checkpoint is used by default.
        Otherwise the path of a checkpoint saved by ``save_checkpoint`` must be
        given as a string.

        Args:
            model_class: Zero-argument constructor of the model architecture.
            model_checkpoint_path: Optional explicit checkpoint path.

        Returns:
            The ``FuncAnimation`` object. Keep a reference to it so it is not
            garbage-collected before it plays. Returns ``None`` after printing a
            message if there is no checkpoint or no grid.
        """
        if model_checkpoint_path is None:
            model_checkpoint_path = self.model_checkpoint_path
        if model_checkpoint_path is None:
            print("Model checkpoint not found. Please train the model first by calling the train function or provide model path.")
            return None
        if not self._has_grid():
            print("Make grid first, Call make_grid method")
            return None

        checkpoint = torch.load(model_checkpoint_path, map_location=self.device)
        model = model_class().to(self.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()

        # Every time slice shares the same (x, v) mesh, so convert it once.
        x_mesh = self.X[:, :, 0].detach().cpu().numpy()
        v_mesh = self.V[:, :, 0].detach().cpu().numpy()

        fig, ax = plt.subplots(figsize=(8, 6))
        colorbar = None

        def update(frame):
            nonlocal colorbar
            ax.clear()
            with torch.no_grad():
                # Slice X, V and T at the same time index so all inputs share shape (nx, nv).
                Z = model(self.X[:, :, frame], self.V[:, :, frame], self.T[:, :, frame], self.nx, self.nv)
            # NOTE(review): assumes the model returns nx * nv values for one time slice.
            z_values = Z.detach().cpu().numpy().reshape(self.nx, self.nv)
            mesh = ax.pcolormesh(x_mesh, v_mesh, z_values, cmap="viridis")
            # Create the colorbar once and retarget it afterwards; adding a new one
            # on every frame would stack colorbars and shrink the plot.
            if colorbar is None:
                colorbar = fig.colorbar(mesh, ax=ax)
            else:
                colorbar.update_normal(mesh)
            ax.set_xlabel("X-axis")
            ax.set_ylabel("V-axis")
            ax.set_title(f"Predicted Distribution at Time {self.T[0, 0, frame].item():.3f}")

        ani = FuncAnimation(fig, update, frames=self.nt, repeat=False)
        plt.show()
        return ani


# Using the Above Script