In [1]:
import h5py
import matplotlib.pyplot as plt
from matplotlib import cm
import os

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler

In [2]:
# Temperature for the two consecutive timestamp
temperature_fields = []

# Folder Path
path = "solutions"
  
# Read text File  
def read_text_file(file_path):
    with h5py.File(file_path, 'r') as f:
        temperature_fields.append(f['temperature'][:])
        
        
# Iterate through all file
for file in os.listdir(path):
    file_path = f"{path}/{file}"
  
    # call read text file function
    read_text_file(file_path)
    #print(f"{file_path} is finished reading")

temperature_fields = np.asarray(temperature_fields)
print(temperature_fields.shape)

(100, 100, 201, 401)


In [3]:
#Parameters
n_epoch = 200
batch_size = 1
lr = 5e-5
betas = (0.9, 0.999)

In [4]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
    
print("Current device is ",device)

# make results determinstic
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Current device is  cpu


In [5]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.encoder = nn.Sequential( # 1x201x401 => 6x23x45
            nn.Conv2d(1, 3, stride=(3, 3), kernel_size=(5, 5), padding=2),
            nn.Tanh(),
            nn.Conv2d(3, 6, stride=(3, 3), kernel_size=(5, 5), padding=2),
            nn.Tanh(),
        )
        
    def forward(self, x):
        out=self.encoder(x)
        return out

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.decoder = nn.Sequential( # 6x23x45 => 201x401
            nn.ConvTranspose2d(6, 3, stride=(3, 3), kernel_size=(5, 5), padding=(2,2)),
            nn.Tanh(),
            nn.ConvTranspose2d(3, 1, stride=(3, 3), kernel_size=(5, 5), padding=(1,0)),
        )
        

    def forward(self, x):
        out=self.decoder(x)
        return out

In [6]:
encoder_path = "2D_ConvAE_results/Conv2D_encoder_best_Gadi.pth"
decoder_path = "2D_ConvAE_results/Conv2D_decoder_best_Gadi.pth"

encoder = Encoder().to(device)
decoder = Decoder().to(device)
encoder.load_state_dict(torch.load(encoder_path, map_location=torch.device('cpu')))
decoder.load_state_dict(torch.load(decoder_path, map_location=torch.device('cpu')))

print("Encoder and Decoder loaded!")

Encoder and Decoder loaded!


In [7]:
# Customised Dataset class
class KMNIST(Dataset):
    
    def __init__(self, dataset):
        self.input = dataset[:,:50,:,:]
        self.output = dataset[:,50:,:,:]
        
    def __len__(self):
        return len(self.input)

    def __getitem__(self, index):
        input_item = self.input[index]
        output_item = self.output[index]
        
        return input_item, output_item

In [8]:
temperature_dataset = KMNIST(
    temperature_fields
)


testingAndValidation_split = 0.2
validation_split = 0.1

# Creating data indices for training, testing and validation splits
# Reference: https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets
temperature_dataset_size = len(temperature_dataset)
temperature_indices = list(range(temperature_dataset_size))

temperature_training_testing_split = int(np.floor(testingAndValidation_split * temperature_dataset_size))
temperature_testing_validation_split = int(np.floor(validation_split * temperature_dataset_size))

np.random.shuffle(temperature_indices)
temperature_train_indices, temperature_val_indices ,temperature_test_indices = temperature_indices[temperature_training_testing_split:], temperature_indices[:temperature_testing_validation_split], temperature_indices[temperature_testing_validation_split:temperature_training_testing_split] 

# Creating data samplers
temperature_train_sampler = SubsetRandomSampler(temperature_train_indices)
temperature_test_sampler = SubsetRandomSampler(temperature_test_indices)
temperature_valid_sampler = SubsetRandomSampler(temperature_val_indices)

train_loader = DataLoader(
    dataset=temperature_dataset,
    batch_size = batch_size,
    sampler=temperature_train_sampler,
)

test_loader = DataLoader(
    dataset=temperature_dataset,
    batch_size = batch_size,
    sampler=temperature_test_sampler,
)

validation_loader = DataLoader(
    dataset=temperature_dataset,
    batch_size = batch_size,
    sampler=temperature_valid_sampler,
)

In [9]:

for X, Y in test_loader:
    print(encoder(X.view(X.shape[1], 1, 201, 401)).reshape(batch_size,50,-1).shape)
    print(encoder(Y.view(Y.shape[1], 1, 201, 401)).shape)


torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])
torch.Size([1, 50, 6210])
torch.Size([50, 6, 23, 45])


In [10]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()    
        self.lstm1 = nn.LSTM(input_size=6210, hidden_size=3105, num_layers=1, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=3105, hidden_size=6210, num_layers=1, batch_first=True)
        
    
    def forward(self, x):
        sequence_length = 50
        
        out, _ = self.lstm1(x)
        out, _ = self.lstm2(out)
        
        out = out.view(sequence_length,6,23,45)
        
        return out

In [11]:
def train(model, encoder, train_loader, val_loader, device, optimizer, n_epoch):
    
    criterion = nn.MSELoss()
    
    minimum_validation_loss = 10000000
    best_model_index = -1
    
    running_loss_list = []
    validation_loss_list = []

    # n_epoch times of iterations
    for epoch in range(n_epoch):

        running_loss = 0.0

        model.train()
        
        for data in train_loader:
            # get a batch of inputs and labels
            inputs, labels = data[0].to(device), data[1].to(device)
            encoded_inputs = encoder(inputs.view(inputs.shape[1], 1, 201, 401)).reshape(batch_size, 50, -1)
            encoded_labels = encoder(labels.view(labels.shape[1], 1, 201, 401))

            # zero the parameter gradients
            optimizer.zero_grad(set_to_none=True)

            # Get output features, calculate loss and optimize
            outputs = model(encoded_inputs)
            loss = criterion(outputs.float(), encoded_labels.float())
            
            loss.backward()
            optimizer.step()

            # Add to the total training loss
            running_loss += loss.item()
            print(loss.item())

        # print some statistics
        print(epoch+1,"epochs have finished")
        print("Current training loss is ",running_loss)
        running_loss_list.append(running_loss)
        running_loss = 0.0

        # Valiadation
        model.eval()
        with torch.no_grad():
            valid_loss = 0.0
            for data in val_loader:
                # get a batch of inputs and labels
                inputs, labels = data[0].to(device), data[1].to(device)
                encoded_inputs = encoder(inputs.view(inputs.shape[1], 1, 201, 401)).reshape(batch_size, 50, -1)
                encoded_labels = encoder(labels.view(labels.shape[1], 1, 201, 401))

                # Get output features, calculate loss and optimize
                outputs = model(encoded_inputs)
                loss = criterion(outputs.float(), encoded_labels.float())

                # Add to the validation loss
                valid_loss += loss.item()

            # Calculate valiadation accuracy and print Validation statistics
            print("Validation loss for this epoch is",valid_loss)
            validation_loss_list.append(valid_loss)

        # Update the statistics for the best model
        if valid_loss <= minimum_validation_loss:
            minimum_validation_loss = valid_loss

            # Store the best models
            PATH = 'lstm_best.pth'

            torch.save(model.state_dict(), PATH)
            print("This model is now saved to Path:",PATH)
            
            best_model_index = epoch
            
        print()
    
    # Training finished, print the statistics for the best model
    print('Finished Training')
    print("Best model has a validation loss of ",minimum_validation_loss)
    print("Best model is in epoch ",best_model_index+1)
    
    # Plot the Training loss and validation loss during training
    plt.figure(figsize=(12, 6))

    plt.subplot(2, 1, 1)
    plt.plot(running_loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Training Loss')
    plt.title('Training Loss in each epoch')

    plt.subplot(2, 1, 2)
    plt.plot(validation_loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Validation Loss')
    plt.title('Validation Loss in each epoch')
    
    plt.subplots_adjust(hspace=1)

    plt.show()

In [12]:
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

In [13]:
train(model, encoder, train_loader, validation_loader, device, optimizer, n_epoch)

0.2614653408527374


KeyboardInterrupt: 

In [None]:
def test(model, encoder, decoder, test_loader, device):

    # Load the model from the input model_path  
    model.load_state_dict(torch.load('lstm_best.pth', map_location=torch.device('cpu')))

    criterion = nn.MSELoss()
    total_loss = 0.0
    
    best_worst_error_list = [1000000, 0]
    best_worst_output_list = [0, 0]
    best_worst_predicted_list = [0, 0]
    
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            encoded_inputs = encoder(inputs.view(inputs.shape[1], 1, 201, 401)).reshape(batch_size, 50, -1)
            encoded_labels = encoder(labels.view(labels.shape[1], 1, 201, 401))

            # Get output features, calculate loss and optimize
            outputs = model(encoded_inputs)
            loss = criterion(outputs.float(), encoded_labels.float())

            for j in range(len(encoded_labels)):
                single_loss = criterion(outputs[j], encoded_labels[j])
                # Record worst error
                if single_loss.item() > best_worst_error_list[1]:
                    best_worst_error_list[1] = single_loss.item()
                    best_worst_output_list[1] = labels[0][j]
                    best_worst_predicted_list[1] = outputs[j]
                    
                # Record best error
                if single_loss.item() < best_worst_error_list[0]:
                    best_worst_error_list[0] = single_loss.item()
                    best_worst_output_list[0] = labels[0][j]
                    best_worst_predicted_list[0] = outputs[j]
                    

            # Add to the validation loss
            total_loss += loss.item()

    # Calculate the overall accuracy and return the accuracy and test loss
    print("Total loss for the model is",total_loss)
    print()
    
    # Draw some plots for the best and the worst error
    print("Best model has a error of ", best_worst_error_list[0])
    
    plt.figure(figsize=(18, 9))
    
    plt.subplot(2,2,1)
    plt.title("Best case output")
    plt.imshow(best_worst_output_list[0].detach().numpy(),
              cmap=cm.get_cmap('jet', 10),
              extent=(0, 2, 0, 1))
    ax = plt.gca()
    ax.set_ylim(ax.get_ylim()[::-1])
    
    
    original_size_predicted = decoder(encoder(best_worst_output_list[0].view(1, 1, 201, 401)).view(1, 6, 23, 45))
    
    plt.subplot(2,2,2)
    plt.title("Best case AE output")
    plt.imshow(original_size_predicted.detach().numpy()[0][0],
              cmap=cm.get_cmap('jet', 10),
              extent=(0, 2, 0, 1))
    ax = plt.gca()
    ax.set_ylim(ax.get_ylim()[::-1])
    
    
    
    latent_space_predicted = best_worst_predicted_list[0].view(1, 6, 23, 45)
    original_size_predicted = decoder(latent_space_predicted)
    
    plt.subplot(2,2,3)
    plt.title("Best case predicted output")
    plt.imshow(original_size_predicted.detach().numpy()[0][0],
              cmap=cm.get_cmap('jet', 10),
              extent=(0, 2, 0, 1))
    ax = plt.gca()
    ax.set_ylim(ax.get_ylim()[::-1])
    
    plt.show()
    
    
    print("Worst model has a error of ", best_worst_error_list[1])
    
    
    plt.figure(figsize=(18, 9))
    
    plt.subplot(2,2,1)
    plt.title("Worst case output")
    plt.imshow(best_worst_output_list[1].detach().numpy(),
              cmap=cm.get_cmap('jet', 10),
              extent=(0, 2, 0, 1))
    ax = plt.gca()
    ax.set_ylim(ax.get_ylim()[::-1])
    
    
    original_size_predicted = decoder(encoder(best_worst_output_list[1].view(1, 1, 201, 401)).view(1, 6, 23, 45))
    
    plt.subplot(2,2,2)
    plt.title("Worst case AE output")
    plt.imshow(original_size_predicted.detach().numpy()[0][0],
              cmap=cm.get_cmap('jet', 10),
              extent=(0, 2, 0, 1))
    ax = plt.gca()
    ax.set_ylim(ax.get_ylim()[::-1])
    
    
    
    latent_space_predicted = best_worst_predicted_list[1].view(1, 6, 23, 45)
    original_size_predicted = decoder(latent_space_predicted)
    
    plt.subplot(2,2,3)
    plt.title("Worst case predicted output")
    plt.imshow(original_size_predicted.detach().numpy()[0][0],
              cmap=cm.get_cmap('jet', 10),
              extent=(0, 2, 0, 1))
    ax = plt.gca()
    ax.set_ylim(ax.get_ylim()[::-1])
    
    plt.show()
    
    #return 100*correct//total, total_loss

In [None]:
test(model, encoder, decoder, test_loader, device)