In [1]:
import numpy as np
import webdataset as wds
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms 
import os
import random
from tqdm import tqdm

# Directory containing the WebDataset shard archives. The trailing "/" matters:
# create_train_test builds shard paths by plain string concatenation.
PATH_TO_DATA = "/mnt/analysis/analysis/rand_sharded_data/" 

In [2]:
# Normalisation transform for the image input (four mean/std entries).
image_normalize = transforms.Normalize(
    mean=[
        0.17960437768666657,
        0.14584139607643212,
        0.10744440357398845,
        0.2583671063835548,
    ],
    std=[
        0.059635202669355195,
        0.04059554002618016,
        0.03371736326989986,
        0.06295501902505744,
    ],
)

# Normalisation transform for the forcing input (seven mean/std entries).
forcing_normalize = transforms.Normalize(
    mean=[
        444.9605606256559,
        991.7980623653417,
        0.00039606951184754176,
        96111.04161525163,
        0.006652783216819315,
        314.3219695851273,
        2.82168247768119,
    ],
    std=[
        5.5216369223813535,
        12.951212256256913,
        0.0002824274832735609,
        975.3770569179914,
        0.00012386107613000674,
        0.6004463118907452,
        0.34279194598853185,
    ],
)

# The same forcing statistics as plain tensors, for manual (x - mean) / std arithmetic.
forcing_mean = torch.from_numpy(
    np.array(
        [
            444.9605606256559,
            991.7980623653417,
            0.00039606951184754176,
            96111.04161525163,
            0.006652783216819315,
            314.3219695851273,
            2.82168247768119,
        ]
    )
)
forcing_std = torch.from_numpy(
    np.array(
        [
            5.5216369223813535,
            12.951212256256913,
            0.0002824274832735609,
            975.3770569179914,
            0.00012386107613000674,
            0.6004463118907452,
            0.34279194598853185,
        ]
    )
)

# Mean / standard deviation of the LST target as one-element tensors.
lst_mean = torch.from_numpy(np.array([312.8291360088677]))
lst_std = torch.from_numpy(np.array([11.376636496297289]))

In [3]:
def create_train_test(path_to_data, train_perc, test_perc):
    """Split the shard files under ``path_to_data`` into train and test WebDatasets.

    Every shard except the last one (in sorted name order) is treated as full
    and counted as 10000 samples; the last shard is iterated once to count its
    samples. Whole shards are then drawn at random for the training set and,
    from the remaining shards, for the test set.

    Args:
        path_to_data: Directory holding the shards. It must end with a path
            separator because shard paths are built by string concatenation.
        train_perc: Fraction of the total sample count used to size the
            training set (rounded down to whole shards).
        test_perc: Fraction of the total sample count used to size the test
            set (rounded down to whole shards).

    Returns:
        ``((train_data, training_samples), (test_data, testing_samples))``
        where the datasets are WebDataset pipelines yielding
        ``(image, forcing, lst)`` tuples and the sample counts are the number
        of samples in the selected shards.
    """
    samples_per_full_shard = 10000

    files = []
    for _dirpath, _dirnames, filenames in os.walk(path_to_data):
        files.extend(filenames)
    # os.walk yields names in arbitrary order; sort so that "the last file"
    # deterministically means the highest-numbered shard.
    files.sort()

    saturated = files[:-1]
    unsaturated = files[-1]
    saturated_set = set(saturated)

    # Only the last shard can be partially filled, so count its samples explicitly.
    dataset = wds.WebDataset(path_to_data + "/" + unsaturated)
    counter = sum(1 for _ in dataset)

    total_files = counter + len(saturated) * samples_per_full_shard
    n_train_shards = int(total_files * train_perc // samples_per_full_shard)
    n_test_shards = int(total_files * test_perc // samples_per_full_shard)

    training_data = random.sample(files, n_train_shards)
    chosen_for_training = set(training_data)
    remaining = [file for file in files if file not in chosen_for_training]
    test_data = random.sample(remaining, n_test_shards)

    def _count_samples(shards):
        """Return the number of samples contained in the given shard names."""
        total = 0
        for name in shards:
            if name in saturated_set:
                total += samples_per_full_shard
            elif name == unsaturated:  # exact match, not a substring test
                total += counter
        return total

    training_samples = _count_samples(training_data)
    testing_samples = _count_samples(test_data)

    # Characters 6..11 of the file name are used as the shard id
    # (assumes names of the form "shard-NNNNNN.tar" -- TODO confirm).
    training_ids = [name[6:12] for name in training_data]
    testing_ids = [name[6:12] for name in test_data]
    training_path = path_to_data + "shard-" + "{" + ",".join(training_ids) + "}" + ".tar"
    testing_path = path_to_data + "shard-{" + ",".join(testing_ids) + "}.tar"

    train_data = (
        wds.WebDataset(training_path)
        .shuffle(10000, initial=10000)
        .decode("rgb")
        .rename(image="image.pyd", forcing="forcing.pyd", lst="lst.pyd")
        .to_tuple("image", "forcing", "lst")
    )
    test_data = (
        wds.WebDataset(testing_path)
        .decode("rgb")
        .shuffle(10000, initial=10000)
        .rename(image="image.pyd", forcing="forcing.pyd", lst="lst.pyd")
        .to_tuple("image", "forcing", "lst")
    )

    return (train_data, training_samples), (test_data, testing_samples)
    
# Build the train/test pipelines with an 85 % / 15 % split of the samples.
(train_data, training_samples_len), (test_data, testing_samples_len) = create_train_test(
    path_to_data=PATH_TO_DATA,
    train_perc=0.85,
    test_perc=0.15,
)


In [4]:
class ResidualBlock(nn.Module):
    """Residual block of three convolutions that doubles the channel count.

    Main path: a 2x2 convolution (stride 1) keeping ``in_channels``, a 1x1
    convolution applying ``stride`` and doubling the channels, then another
    1x1 convolution. Each convolution is followed by batch normalisation.
    Shortcut path: a single 2x2 convolution with the same ``stride`` that
    produces ``2 * in_channels`` channels so it can be added to the main path.
    """

    def __init__(self, in_channels, stride=1):
        super().__init__()
        out_channels = in_channels * 2
        self.stride = stride

        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=2, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Projection applied to the block input so it matches the main path's shape.
        self.match_conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=stride, padding=0)

    def match_input(self, x):
        """Project the block input to the shape produced by the main path."""
        return self.match_conv(x)

    def forward(self, x):
        """Return ``leaky_relu(main_path(x) + match_input(x))``."""
        shortcut = self.match_input(x)

        out = F.leaky_relu(self.bn1(self.conv1(x)))
        out = F.leaky_relu(self.bn2(self.conv2(out)))
        # No activation before the residual addition.
        out = self.bn3(self.conv3(out))

        out += shortcut
        return F.leaky_relu(out)

In [5]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to an embedding in (0, 1).

    Three ``ResidualBlock`` stages are followed by two fully connected layers;
    the final activation is a sigmoid.

    Args:
        in_channels: Number of channels of the input image.
        embedding_dim: Size of the produced embedding vector.
        input_size: Height/width of the (square) input used to size the first
            fully connected layer. Defaults to 33.
    """

    def __init__(self, in_channels, embedding_dim, input_size=33):
        super().__init__()
        # Residual feature extraction
        self.block_1 = ResidualBlock(in_channels, stride=1)
        self.pool_1 = nn.AvgPool2d(2)
        self.block_2 = ResidualBlock(in_channels*2, stride=2)
        self.pool_2 = nn.AvgPool2d(2)
        self.block_3 = ResidualBlock(in_channels*4, stride=2)
        # Previously this overwrote ``pool_2``; give the third pool its own name.
        # NOTE: none of the pooling layers is applied in ``convolutions``.
        self.pool_3 = nn.AvgPool2d(2)

        self.flatten_shape = self._infer_flatten_shape(in_channels, input_size)

        self.fc1 = nn.Linear(self.flatten_shape, 512)
        self.fc2 = nn.Linear(512, embedding_dim)

    def _infer_flatten_shape(self, in_channels, input_size):
        """Return the flattened feature size for one ``input_size`` square image."""
        was_training = self.training
        # Probe in eval mode so the all-zero dummy batch does not update the
        # BatchNorm running statistics of the residual blocks.
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(size=(1, in_channels, input_size, input_size))
                return self.convolutions(probe).shape[1]
        finally:
            self.train(was_training)

    def convolutions(self, x):
        """Run the three residual blocks and flatten to ``(batch, features)``."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = x.view(x.shape[0], -1)
        return x

    def forward(self, x):
        """Encode a batch of images into embeddings of size ``embedding_dim``."""
        x = self.convolutions(x)
        x = F.leaky_relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x


In [6]:
class Decoder(nn.Module):
    """Decode an embedding vector into a 4-channel 33x33 image in [0, 1].

    Two fully connected layers expand the embedding to a 32x8x8 feature map,
    three stride-2 transposed convolutions upsample it to 4x64x64, a sigmoid
    is applied, and the result is resized to 33x33 with ``F.interpolate``.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, 512)
        self.fc2 = nn.Linear(512, 32 * 8 * 8)

        self.convt1 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2, padding=0)
        self.bn1 = nn.BatchNorm2d(16, momentum=0.01)
        self.convt2 = nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2, padding=0)
        self.bn2 = nn.BatchNorm2d(8, momentum=0.01)
        self.convt3 = nn.ConvTranspose2d(8, 4, kernel_size=2, stride=2, padding=0)
        self.bn3 = nn.BatchNorm2d(4, momentum=0.01)

    def forward(self, x):
        """Map embeddings ``(batch, embedding_dim)`` to images ``(batch, 4, 33, 33)``."""
        hidden = F.leaky_relu(self.fc1(x))
        feature_map = F.leaky_relu(self.fc2(hidden)).view(-1, 32, 8, 8)

        # The first two upsampling stages share the same conv -> norm -> leaky_relu shape.
        for upsample, norm in ((self.convt1, self.bn1), (self.convt2, self.bn2)):
            feature_map = F.leaky_relu(norm(upsample(feature_map)))

        # The last stage uses a sigmoid so the output lies in [0, 1].
        image = torch.sigmoid(self.bn3(self.convt3(feature_map)))
        return F.interpolate(image, size=(33, 33))

In [7]:
class AutoEncoder(nn.Module):
    """Autoencoder chaining an ``Encoder`` and a ``Decoder``.

    Args:
        embedding_dim: Size of the bottleneck embedding shared by both halves.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.encoder = Encoder(in_channels=4, embedding_dim=embedding_dim)
        self.decoder = Decoder(embedding_dim=embedding_dim)

    def forward(self, x):
        """Encode ``x`` and decode the resulting embedding."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def save_encoder_decoder(self, encoder_path="encoder.pt", decoder_path="decoder.pt"):
        """Save the encoder and decoder modules to separate files.

        The whole module objects are written with ``torch.save`` (not just
        their state dicts).

        Args:
            encoder_path: Destination file for the encoder. Defaults to
                ``"encoder.pt"`` in the current working directory.
            decoder_path: Destination file for the decoder. Defaults to
                ``"decoder.pt"`` in the current working directory.
        """
        torch.save(self.encoder, encoder_path)
        torch.save(self.decoder, decoder_path)
    
        

In [None]:
# Autoencoder training: reconstruct the (clipped) input image with an MSE loss
# and save the encoder/decoder whenever the test loss improves.
EPOCHS = 30
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1024

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=6)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=6)

model = AutoEncoder(embedding_dim=16).to(DEVICE)
# Use at most two GPUs, but only ones that actually exist. With no GPU the
# list is empty and DataParallel is left to its default device handling.
gpu_ids = list(range(min(2, torch.cuda.device_count()))) or None
model = torch.nn.DataParallel(model, device_ids=gpu_ids)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
test_loss = []
train_loss = []

lst_mean = lst_mean.to(DEVICE)
lst_std = lst_std.to(DEVICE)
forcing_mean = forcing_mean.to(DEVICE)
forcing_std = forcing_std.to(DEVICE)

def process_data(image):
    """Cast the image batch to float32, move it to DEVICE and clip it to [0, 1]."""
    image = image.to(torch.float32).to(DEVICE)
    image = torch.clip(image, min=0, max=1)
    return image

min_test_loss = np.inf    
for epoch in range(EPOCHS):
    print("****** EPOCH: [{}/{}] LR: {} ******".format(epoch, EPOCHS, round(optimizer.param_groups[0]['lr'], 6)))
    running_train_loss = 0
    train_n_iter = 0
    running_test_loss = 0
    test_n_iter = 0

    # Training mode: BatchNorm uses batch statistics and updates its running stats.
    model.train()
    loop_train = tqdm(train_loader, total=(training_samples_len//BATCH_SIZE) + 1, leave=True)
    for image, _, _ in loop_train:
        image = process_data(image)
        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks run as usual.
        forward_out = model(image)
        # The clipped input itself is the reconstruction target.
        loss = loss_fn(forward_out, image)
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()
        train_n_iter += 1
        loop_train.set_postfix(train_loss=loss.item())

    loop_test = tqdm(test_loader, total=(testing_samples_len//BATCH_SIZE) + 1, leave=False)

    # Evaluation mode: BatchNorm must use its running statistics and must not
    # be updated by test batches.
    model.eval()
    with torch.no_grad():
        for image, _, _ in loop_test:
            image = process_data(image)
            pred = model(image)
            testloss = loss_fn(pred, image)
            running_test_loss += testloss.item()
            test_n_iter += 1
            loop_test.set_postfix(test_loss=testloss.item())

    avg_train_loss = running_train_loss/train_n_iter
    train_loss.append(avg_train_loss)
    avg_test_loss = running_test_loss/test_n_iter
    test_loss.append(avg_test_loss)

    scheduler.step()
    scheduler2.step(avg_test_loss)
    if avg_test_loss < min_test_loss:
        print("Saving Model")
        min_test_loss = avg_test_loss
        # model is the DataParallel wrapper; the AutoEncoder lives in .module.
        model.module.save_encoder_decoder()
    print("------ Train Loss: {}, Test Loss: {} ------".format(avg_train_loss, avg_test_loss))

****** EPOCH: [0/30] LR: 0.001 ******


 11%|█████████▊                                                                                | 186/1710 [00:30<03:21,  7.55it/s, train_loss=0.00545]

In [None]:
import matplotlib.pyplot as plt
# Plot one point per recorded epoch. The x-axis is derived from the loss
# lists themselves; a hard-coded length raises a shape-mismatch error
# whenever it differs from the number of epochs actually run.
plt.plot(range(len(train_loss)), train_loss, label = "train_loss")
plt.plot(range(len(test_loss)), test_loss, label = "test_loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training Error LST Prediction")
plt.savefig("Training Curve.png")
plt.show()

In [None]:
# Load the pickled encoder module saved by AutoEncoder.save_encoder_decoder.
# map_location="cpu" lets this work on machines without the GPU it was saved
# from. The encoder is only used as a fixed feature extractor below, so it is
# switched to eval mode: BatchNorm then uses its running statistics instead of
# per-batch statistics and is no longer updated.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# files; newer PyTorch releases may also need weights_only=False here.
embedding_encoder = torch.load("encoder.pt", map_location="cpu").to("cpu").eval()

In [None]:
class LSTModel(nn.Module):
    """Fully connected regressor over an image embedding plus forcing features.

    The embedding and the forcing vector are concatenated along dim 1 and
    passed through two hidden layers (128 and 64 units, each followed by
    BatchNorm1d, leaky ReLU and a dropout layer with probability 0) and a
    final single-output linear layer.

    Args:
        forcing_shape: Shape of one forcing sample; its last entry is the
            number of forcing features. A plain int is also accepted.
            Defaults to ``(1, 7)``.
        embedding_dim: Size of the image embedding. Defaults to 16.
    """

    def __init__(self, forcing_shape=(1,7), embedding_dim=16):
        super().__init__()
        # Derive the input width from the arguments instead of hard-coding 16+7.
        if isinstance(forcing_shape, (tuple, list, torch.Size)):
            n_forcing = int(forcing_shape[-1])
        else:
            n_forcing = int(forcing_shape)

        self.fc1 = nn.Linear(in_features=embedding_dim + n_forcing, out_features=128)
        self.fc_bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc_bn2 = nn.BatchNorm1d(64)
        self.drop2 = nn.Dropout(0)
        self.fc3 = nn.Linear(in_features=64, out_features=1)

    def forward(self, x, forcing):
        """Return a ``(batch, 1)`` prediction from embedding ``x`` and ``forcing``."""
        x = torch.cat((x, forcing), dim=1)
        x = F.leaky_relu(self.fc_bn1(self.fc1(x)))
        x = self.drop1(x)
        x = F.leaky_relu(self.fc_bn2(self.fc2(x)))
        x = self.drop2(x)
        x = self.fc3(x)

        return x

In [None]:
# LST regression training: the frozen encoder turns each image into an
# embedding, which LSTModel combines with the normalised forcing vector.
# The model's state dict is saved whenever the test loss improves.
EPOCHS = 50
LEARNING_RATE = 0.008
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1024

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=6)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=6)

model = LSTModel().to(DEVICE)
# Use at most two GPUs, but only ones that actually exist. With no GPU the
# list is empty and DataParallel is left to its default device handling.
gpu_ids = list(range(min(2, torch.cuda.device_count()))) or None
model = torch.nn.DataParallel(model, device_ids=gpu_ids)
loss_fn = nn.SmoothL1Loss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.95)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
test_loss = []
train_loss = []

lst_mean = lst_mean.to(DEVICE)
lst_std = lst_std.to(DEVICE)
forcing_mean = forcing_mean.to(DEVICE)
forcing_std = forcing_std.to(DEVICE)

# The encoder is a fixed feature extractor here: keep its BatchNorm layers on
# their running statistics.
embedding_encoder.eval()

def process_data(image, forcing, lst):
    """Turn one raw batch into ``(embedding, normalised forcing, lst target)``.

    The image is clipped to [0, 1] and embedded by ``embedding_encoder`` (on
    the device the encoder lives on) before being moved to DEVICE; forcing is
    standardised with ``forcing_mean`` / ``forcing_std``; lst is reshaped to
    ``(batch, 1)``. All outputs are float32.
    """
    image, forcing, lst = image.to(torch.float32), forcing.to(DEVICE), lst.to(DEVICE)
    # Image Transformations
    image = torch.clip(image, min=0, max=1)
    # The encoder is not being trained, so no autograd graph is needed; without
    # no_grad, every backward pass would also push gradients into the encoder.
    with torch.no_grad():
        image = embedding_encoder(image).to(DEVICE)
    # Forcing Transformation
    forcing = torch.div(torch.sub(forcing, forcing_mean), forcing_std).to(torch.float32)
    # LST Transformation (the target is left in its original units)
    lst = lst.view(-1, 1).to(torch.float32)
    return image, forcing, lst

min_test_loss = np.inf    
for epoch in range(EPOCHS):
    print("****** EPOCH: [{}/{}] LR: {} ******".format(epoch, EPOCHS, round(optimizer.param_groups[0]['lr'], 6)))
    running_train_loss = 0
    train_n_iter = 0
    running_test_loss = 0
    test_n_iter = 0

    # Training mode: BatchNorm uses batch statistics and updates its running stats.
    model.train()
    loop_train = tqdm(train_loader, total=(training_samples_len//BATCH_SIZE) + 1, leave=True)
    for image, forcing, lst in loop_train:
        image, forcing, lst = process_data(image, forcing, lst)
        optimizer.zero_grad()
        # Exactly one forward pass per step; an extra discarded pass would
        # double the cost and update the BatchNorm statistics twice.
        forward_out = model(image, forcing)
        loss = loss_fn(forward_out, lst)
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()
        train_n_iter += 1
        loop_train.set_postfix(train_loss=loss.item())

    loop_test = tqdm(test_loader, total=(testing_samples_len//BATCH_SIZE) + 1, leave=False)

    # Evaluation mode: BatchNorm must use its running statistics and must not
    # be updated by test batches.
    model.eval()
    with torch.no_grad():
        for image, forcing, lst in loop_test:
            image, forcing, lst = process_data(image, forcing, lst)
            pred = model(image, forcing)
            testloss = loss_fn(pred, lst)
            running_test_loss += testloss.item()
            test_n_iter += 1
            loop_test.set_postfix(test_loss=testloss.item())

    avg_train_loss = running_train_loss/train_n_iter
    train_loss.append(avg_train_loss)
    avg_test_loss = running_test_loss/test_n_iter
    test_loss.append(avg_test_loss)

    scheduler.step()
    scheduler2.step(avg_test_loss)
    if avg_test_loss < min_test_loss:
        print("Saving Model")
        min_test_loss = avg_test_loss
        # Saved from the DataParallel wrapper, so keys carry the "module." prefix.
        torch.save(model.state_dict(), "lstmodel.pt")
    print("------ Train Loss: {}, Test Loss: {} ------".format(avg_train_loss, avg_test_loss))