In [1]:
from torch.utils.data import Dataset
import torch
import numpy as np
import h5py
import pytorch_lightning as pl
from torch import nn
import matplotlib.pyplot as plt
# import models


class model_wraper(pl.LightningModule):
    """Lightning wrapper that trains a network from the local ``models`` module with an MSE loss.

    NOTE(review): cell 2 also uses ``models.model_wraper`` for the graph model; this
    notebook-local copy may drift from that one. Consider importing a single definition.

    Args:
        config: dict of hyper-parameters. Keys read here:
            ``"MODEL_NAME"``: name of a callable in ``models``, which is called with ``config``.
            ``"learning_rate"`` and ``"weight_decay"``: AdamW settings.
    """

    def __init__(self, config):
        super().__init__()
        import importlib  # local import keeps this cell self-contained

        module = importlib.import_module("models")
        self.model = getattr(module, config["MODEL_NAME"])(config)
        self.criterion_mse = nn.MSELoss()
        self.config = config
        # Every validation step appends here; nothing in this class clears these lists.
        self.val_pred = []
        self.val_loss = []

    def forward(self, batch):
        """Run the wrapped network on ``batch``."""
        return self.model(batch)

    def training_step(self, batch, batch_idx):
        """Return the MSE between ``self(batch[0])`` and the target ``batch[1]``."""
        inputs, target = batch[0], batch[1]
        pred = self.forward(inputs)
        return self.criterion_mse(pred, target)

    def validation_step(self, batch, batch_idx):
        """Compute the validation MSE and record ``[target, pred]`` and the loss.

        The stored tensors are detached so they never keep an autograd graph alive. This
        matters if the step is called outside Lightning's no-grad validation loop.
        """
        inputs, target = batch[0], batch[1]
        pred = self.forward(inputs)
        loss = self.criterion_mse(pred, target)
        self.val_pred.append([target.detach(), pred.detach()])
        self.val_loss.append(loss.detach())
        return loss

    def configure_optimizers(self):
        """AdamW over the wrapped network's parameters, using ``config`` lr / weight decay."""
        return torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.config['learning_rate'],
            weight_decay=self.config['weight_decay'],
        )

    def load_model_state(self, PATH, map_location='cuda:0'):
        """Load the weights stored under ``checkpoint['state_dict']`` at ``PATH`` into ``self.model``.

        A Lightning checkpoint of this wrapper stores keys as ``"model.<name>"`` because the
        network lives in ``self.model``. Such keys are stripped of that prefix before loading.
        A state dict taken from the bare network is loaded as-is.

        Args:
            PATH: checkpoint file passed to ``torch.load``.
            map_location: device to remap tensors onto. The default keeps the previous
                ``'cuda:0'`` behaviour; pass ``'cpu'`` on machines without a GPU.
        """
        checkpoint = torch.load(PATH, map_location=map_location)
        state_dict = checkpoint['state_dict']
        prefix = "model."
        model_keys = self.model.state_dict().keys()
        needs_strip = (
            any(key not in model_keys for key in state_dict)
            and all(key.startswith(prefix) for key in state_dict)
        )
        if needs_strip:
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        self.model.load_state_dict(state_dict)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch import nn
import matplotlib.pyplot as plt
import load_data as ld
import models

# Load the three trained models compared in the next cell.
# Every checkpoint is remapped onto the CPU, so this cell runs on nodes without a GPU.
DATA_PATH = "../data/batch1.hdf5"
SAVES_DIR = "/gpfs/data/fs72150/springerd/Projects/LuttingerWard_from_ML/saves"
# Optimiser/batch settings shared by all three runs (needed only to rebuild the wrappers).
TRAIN_DEFAULTS = {"batch_size": 1, "learning_rate": 1e-4, "weight_decay": 0}


def add_autoencoder_config(config, model_name, in_dim, embedding_dim=128):
    """Add the auto-encoder hyper-parameters to ``config`` in place and return it.

    The hidden sizes are fixed fractions (1/2, 1/4, 1/8) of ``embedding_dim``.
    """
    config["MODEL_NAME"] = model_name
    config["in_dim"] = in_dim
    config.update(TRAIN_DEFAULTS)
    config["embedding_dim"] = embedding_dim  # previously considered: int(in_dim / 2)
    config["hidden1_dim"] = int(embedding_dim / 2)
    config["hidden2_dim"] = int(embedding_dim / 4)
    config["encoder_dim"] = int(embedding_dim / 8)
    return config


def load_checkpoint(wrapper_cls, config, ckpt_path):
    """Build ``wrapper_cls(config)`` and load the checkpoint's full ``'state_dict'`` into it."""
    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
    wrapped = wrapper_cls(config)
    wrapped.load_state_dict(checkpoint['state_dict'])
    return wrapped


# Each dataset is built from a config holding only PATH_TRAIN.
# The model keys are then added to that same dict.

# Convolutional auto-encoder
config_conv = {"PATH_TRAIN": DATA_PATH}
data_set_conv = ld.Dataset_baseline_conv(config_conv)
add_autoencoder_config(config_conv, "auto_encoder_conv", data_set_conv.data_in.shape[1])
model_conv = load_checkpoint(
    model_wraper,
    config_conv,
    f"{SAVES_DIR}/save_auto_encoder_conv_2023-12-11/version_0/checkpoints/epoch=99-step=120000.ckpt",
)

# Linear (fully connected) auto-encoder
config_ae = {"PATH_TRAIN": DATA_PATH}
data_set = ld.Dataset_baseline(config_ae)
add_autoencoder_config(config_ae, "auto_encoder", data_set.data_in.shape[1])
model = load_checkpoint(
    model_wraper,
    config_ae,
    f"{SAVES_DIR}/save_auto_encoder_2023-12-11/version_0/checkpoints/epoch=99-step=120000.ckpt",
)

# Graph neural network
# It uses the wrapper defined in `models`, not the notebook-local one.
config_graph = {"PATH_TRAIN": DATA_PATH}
data_set_graph = ld.Dataloader_graph(config_graph)
config_graph["MODEL_NAME"] = "GreenGNN"
config_graph.update(TRAIN_DEFAULTS)
# Alternative checkpoint (relative path): "save_GreenGNN_2023-12-14/version_18/checkpoints/epoch=1-step=240000.ckpt"
model_graph = load_checkpoint(
    models.model_wraper,
    config_graph,
    f"{SAVES_DIR}/save_GreenGNN_2023-12-14/version_0/checkpoints/epoch=19-step=200000.ckpt",
)

<All keys matched successfully>

In [10]:
# Compare the three models' reconstructions on one randomly chosen sample.
SEED = None  # set an int to make the sample choice reproducible
PLOT_WINDOW = slice(100, 199)  # NOTE(review): covers 99 points; confirm 100:200 was not intended

rng = np.random.default_rng(SEED)
random_sample = int(rng.integers(len(data_set)))

# Index each dataset once and reuse the result.
sample = data_set[random_sample]
sample_conv = data_set_conv[random_sample]
in_sample_graph = data_set_graph[random_sample]

in_sample, target = sample[0], sample[1]
in_sample_conv, target_conv = sample_conv[0], sample_conv[1]
# Add a leading batch dimension of size 1 to each graph input passed to the GNN.
in_sample_graph_idx = {
    key: in_sample_graph[key][None] for key in ("edge_index", "node_feature", "vectors")
}
target_graph = in_sample_graph["target"]

# Inference only:
# - eval() turns off training-only behaviour (e.g. dropout, batch-norm statistic updates) if the
#   networks have any.
# - no_grad() avoids building an autograd graph.
for wrapped in (model, model_conv, model_graph):
    wrapped.eval()
with torch.no_grad():
    prediction = model.model(in_sample)
    prediction_conv = model_conv.model(in_sample_conv[None, :, :])
    prediction_graph = model_graph.model(in_sample_graph_idx)

plot_style = {
    'axes.edgecolor': 'black',
    'xtick.color': 'black',
    'ytick.color': 'black',
    'figure.facecolor': 'white',
}
with plt.rc_context(plot_style):
    # Always create a new figure.
    # A random figure number could collide with an existing figure and draw into it.
    fig, ax = plt.subplots()
    ax.plot(target[PLOT_WINDOW].cpu(), label="Target")
    ax.plot(prediction[PLOT_WINDOW].detach().cpu().numpy(), label="Linear AE")
    ax.plot(prediction_conv[0, PLOT_WINDOW].detach().cpu().numpy(), label="Conv AE")
    # NOTE(review): the graph prediction is plotted in full, not windowed.
    # Confirm it lines up with the other curves.
    ax.plot(prediction_graph.detach().cpu().numpy(), label="Graph")
    ax.set_title(f"Sample {random_sample}: target vs. reconstructions")
    ax.set_xlabel("position in plotted slice")
    ax.legend()
