In [1]:
import numpy as np
import webdataset as wds
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms 
import torchvision.models as models
import os
import random

# Root directory that holds the WebDataset shard files.
# The location can be overridden with the LST_SHARD_DIR environment variable so the
# notebook does not have to be edited per machine; the default is the original path.
# Joining with "" guarantees a trailing path separator, so file names can be
# appended to PATH_TO_DATA directly.
PATH_TO_DATA = os.path.join(
    os.environ.get("LST_SHARD_DIR", "/mnt/analysis/analysis/rand_sharded_data/"),
    "",
)

In [8]:
# Statistics used to standardise the model inputs. Each value is written once here
# and reused for both the torchvision transforms and the tensor versions below.
#
# Last image channel (presumably elevation, given the elevation_* names below):
#   with clipping to 600: 4.0999e+02, 7.4237e+01
#   without clipping:     4.1454e+02, 8.9422e+01
_ELEVATION_MEAN = 4.0999e+02
_ELEVATION_STD = 7.4237e+01

# One entry per forcing feature (7 features).
_FORCING_MEAN = [
    444.9605606256559,
    991.7980623653417,
    0.00039606951184754176,
    96111.04161525163,
    0.006652783216819315,
    314.3219695851273,
    2.82168247768119,
]
_FORCING_STD = [
    5.5216369223813535,
    12.951212256256913,
    0.0002824274832735609,
    975.3770569179914,
    0.00012386107613000674,
    0.6004463118907452,
    0.34279194598853185,
]

_LST_MEAN = 312.8291360088677
_LST_STD = 11.376636496297289

# Image normalisation: the first four channels pass through unchanged
# (mean 0, std 1); only the fifth channel is standardised.
image_normalize = transforms.Normalize(
    mean=[0, 0, 0, 0, _ELEVATION_MEAN],
    std=[1, 1, 1, 1, _ELEVATION_STD],
)

forcing_normalize = transforms.Normalize(
    mean=list(_FORCING_MEAN),
    std=list(_FORCING_STD),
)

# Tensor copies of the same statistics (float64, as produced by numpy's default dtype).
elevation_mean = torch.tensor([_ELEVATION_MEAN], dtype=torch.float64)
elevation_std = torch.tensor([_ELEVATION_STD], dtype=torch.float64)

forcing_mean = torch.tensor(_FORCING_MEAN, dtype=torch.float64)
forcing_std = torch.tensor(_FORCING_STD, dtype=torch.float64)

lst_mean = torch.tensor([_LST_MEAN], dtype=torch.float64)
lst_std = torch.tensor([_LST_STD], dtype=torch.float64)

In [9]:
def create_train_test(path_to_data, train_perc, test_perc, seed=None):
    """Split the shard files under ``path_to_data`` into train and test WebDatasets.

    Every shard except the last one (in sorted file-name order) is treated as
    holding exactly 10000 samples; the last shard is iterated once to count its
    samples. The number of shards assigned to each split is the requested
    fraction of the total sample count, floor-divided by 10000.

    Args:
        path_to_data: Directory containing the shard files. Only bare file names
            are collected, and shard URLs are rebuilt as ``shard-<id>.tar``
            directly under this directory, where ``<id>`` is characters 6-11 of
            each file name.
        train_perc: Fraction of all samples to assign to the training split.
        test_perc: Fraction of all samples to assign to the test split.
        seed: Optional seed for the shard selection. ``None`` (the default) draws
            from the module-level ``random`` generator exactly as before; any
            other value makes the split reproducible.

    Returns:
        ``((train_dataset, n_train_samples), (test_dataset, n_test_samples))``
        where each dataset yields ``(image, forcing, lst)`` tuples.

    Raises:
        ValueError: If no files are found under ``path_to_data``.
    """
    samples_per_full_shard = 10000

    files = []
    for _dirpath, _dirnames, filenames in os.walk(path_to_data):
        files.extend(filenames)
    if not files:
        raise ValueError("no shard files found under {!r}".format(path_to_data))

    # os.walk yields names in arbitrary directory order; sort so that "the last
    # shard" is well defined instead of depending on the listing order.
    files.sort()
    saturated = set(files[:-1])
    unsaturated = files[-1]

    # Shard paths are built by concatenation, so make sure the base ends with "/".
    base = path_to_data if path_to_data.endswith("/") else path_to_data + "/"

    # The last shard may be partially filled: count its samples by iterating it once.
    counter = sum(1 for _ in wds.WebDataset(base + unsaturated))

    total_samples = counter + (len(files) - 1) * samples_per_full_shard
    n_train_shards = int(total_samples * train_perc // samples_per_full_shard)
    n_test_shards = int(total_samples * test_perc // samples_per_full_shard)

    # Use the global generator unless a seed was requested, so existing callers
    # (including ones that seed ``random`` themselves) behave as before.
    rng = random if seed is None else random.Random(seed)
    training_files = rng.sample(files, n_train_shards)
    training_set = set(training_files)
    remaining_files = [name for name in files if name not in training_set]
    test_files = rng.sample(remaining_files, n_test_shards)

    def count_samples(selected):
        # Full shards contribute 10000 samples, the last shard its counted size.
        total = 0
        for name in selected:
            if name in saturated:
                total += samples_per_full_shard
            elif name == unsaturated:
                total += counter
        return total

    def shard_url(selected):
        # Brace-expansion URL understood by WebDataset: shard-{id1,id2,...}.tar
        shard_ids = ",".join(name[6:12] for name in selected)
        return base + "shard-{" + shard_ids + "}.tar"

    training_samples = count_samples(training_files)
    testing_samples = count_samples(test_files)

    train_data = (
        wds.WebDataset(shard_url(training_files))
        .shuffle(10000, initial=10000)
        .decode("rgb")
        .rename(image="image.pyd", forcing="forcing.pyd", lst="lst.pyd")
        .to_tuple("image", "forcing", "lst")
    )
    test_data = (
        wds.WebDataset(shard_url(test_files))
        .decode("rgb")
        .shuffle(10000, initial=10000)
        .rename(image="image.pyd", forcing="forcing.pyd", lst="lst.pyd")
        .to_tuple("image", "forcing", "lst")
    )

    return (train_data, training_samples), (test_data, testing_samples)
    
# Build the datasets: 85 % of the samples go to training, 15 % to testing.
(train_data, training_samples_len), (test_data, testing_samples_len) = create_train_test(
    path_to_data=PATH_TO_DATA,
    train_perc=0.85,
    test_perc=0.15,
)


In [10]:
# Sanity check: pull one sample from the training stream and show the shapes of
# its image and forcing entries.
data = next(iter(train_data), None)
if data is not None:
    print(data[0].shape)
    print(data[1].shape)

(5, 33, 33)
(7,)


In [11]:
class PatchEmbedding(nn.Module):
    """
    Cuts an image batch into square patches and embeds every patch as a vector.

    A convolution whose kernel size and stride both equal ``patch_size`` does the
    work, so each output position corresponds to one non-overlapping patch.
    """

    def __init__(self, image_size, patch_size, in_channels=3, embedding=768):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size

        # Patches along one side of the image, squared for the full grid.
        patches_per_side = int(image_size / patch_size)
        self.num_patches = patches_per_side ** 2

        self.patching_conv = nn.Conv2d(in_channels, embedding,
                                       kernel_size=patch_size,
                                       stride=patch_size)

    def forward(self, x):
        # conv:      (samples, channels, H, W)         -> (samples, embedding, rows, cols)
        # flatten:   (samples, embedding, rows, cols)  -> (samples, embedding, n_patches)
        # transpose: (samples, embedding, n_patches)   -> (samples, n_patches, embedding)
        return self.patching_conv(x).flatten(2).transpose(1, 2)

class Attention(nn.Module):
    """
    Multi-head scaled dot-product self-attention (nearly identical to the
    original Transformer paper).

    Args:
        embedding: Size of each token's embedding vector.
        num_heads: Number of attention heads; must divide ``embedding`` evenly.
        qkv_b: Whether the joint Q/K/V linear layer has a bias term.
        attention_drop_p: Dropout probability applied to the attention weights.
        projection_drop_p: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``embedding`` is not divisible by ``num_heads``.
    """
    def __init__(self, embedding, num_heads, qkv_b=True, attention_drop_p=0, projection_drop_p=0):
        super().__init__()
        if embedding % num_heads != 0:
            # Without this the reshape in forward() fails with an opaque shape error.
            raise ValueError(
                "embedding ({}) must be divisible by num_heads ({})".format(embedding, num_heads)
            )
        self.embedding = embedding # Size of embedding vector
        self.num_heads = num_heads # Number of heads in multiheaded attention layers
        self.qkv_b = qkv_b # Do we want a bias term on our QKV linear layer
        self.attention_drop_p = attention_drop_p # Attention layer dropout probability
        self.projection_drop_p = projection_drop_p # Projection layer dropout probability
        self.head_dimension = embedding // num_heads # Dimension of each head in multiheaded attention
        self.scaling = self.head_dimension ** 0.5 # sqrt(head_dim) divisor from the original transformer paper

        # qkv_b is now actually honoured (it used to be stored but ignored).
        self.qkv = nn.Linear(embedding, embedding * 3, bias=qkv_b)
        self.attention_drop = nn.Dropout(self.attention_drop_p)
        self.projection = nn.Linear(embedding, embedding)
        self.projection_drop = nn.Dropout(self.projection_drop_p)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (samples, tokens, embedding).

        Returns a tensor of the same shape.
        """
        samples, patches, embedding = x.shape

        # One linear layer produces Q, K and V together: (samples, tokens, 3*embedding)
        qkv = self.qkv(x)

        # Split the last axis into (3, heads, head_dim); num_heads * head_dim == embedding.
        qkv = qkv.reshape(samples, patches, 3, self.num_heads, self.head_dimension)

        # Bring Q/K/V to the front and heads before tokens:
        # (3, samples, heads, tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # (samples, heads, tokens, head_dim) @ (samples, heads, head_dim, tokens)
        # -> (samples, heads, tokens, tokens), scaled by sqrt(head_dim)
        scaled_mult = torch.matmul(q, k.transpose(-2, -1)) / self.scaling

        # Normalise the scores over the key axis and apply attention dropout.
        attention = scaled_mult.softmax(dim=-1)
        attention = self.attention_drop(attention)

        # Weighted average of the values: (samples, heads, tokens, head_dim)
        weighted_average = torch.matmul(attention, v)

        # Back to (samples, tokens, heads, head_dim), then merge the heads again
        # to recover (samples, tokens, embedding).
        weighted_average = weighted_average.transpose(1, 2).flatten(2)

        # Project the ATTENDED values. Projecting the raw input ``x`` here (as the
        # code previously did) throws away the whole attention computation.
        return self.projection_drop(self.projection(weighted_average))

class MultiLayerPerceptron(nn.Module):
    """
    Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.
    """
    def __init__(self, input_features, hidden_features, output_features, dropout_p=0):
        super().__init__()
        # Layers are registered in the order they are applied.
        self.fc1 = nn.Linear(in_features=input_features, out_features=hidden_features)
        self.drop_1 = nn.Dropout(p=dropout_p)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=output_features)
        self.drop_2 = nn.Dropout(p=dropout_p)

    def forward(self, x):
        # Expand to the hidden width, apply the non-linearity, then regularise.
        hidden = self.fc1(x)
        hidden = self.gelu(hidden)
        hidden = self.drop_1(hidden)

        # Project to the output width and regularise once more.
        projected = self.fc2(hidden)
        return self.drop_2(projected)


class TransformerBlock(nn.Module):
    """
    Self-attention block with layer normalization.

    Layer normalization is applied before each sub-layer (attention, then the
    feed-forward MLP) and the sub-layer output is added back to its input.
    """
    def __init__(self, embedding, num_heads, hidden_features=2048,qkv_b=True, attention_dropout_p=0,
                 projection_dropout_p=0, mlp_dropout_p=0):
        super().__init__()
        # Attention sub-layer and its normalization.
        self.layernorm1 = nn.LayerNorm(embedding, eps=1e-6)
        self.attention = Attention(
            embedding=embedding, num_heads=num_heads, qkv_b=qkv_b,
            attention_drop_p=attention_dropout_p, projection_drop_p=projection_dropout_p,
        )

        # Feed-forward sub-layer and its normalization; the MLP maps back to
        # the embedding width so the residual addition is shape-compatible.
        self.layernorm2 = nn.LayerNorm(embedding, eps=1e-6)
        self.feedforward = MultiLayerPerceptron(
            input_features=embedding, hidden_features=hidden_features,
            output_features=embedding, dropout_p=mlp_dropout_p,
        )

    def forward(self, x):
        # Residual connection around the attention sub-layer.
        attended = self.attention(self.layernorm1(x))
        x = x + attended

        # Residual connection around the feed-forward sub-layer.
        transformed = self.feedforward(self.layernorm2(x))
        return x + transformed

class VisionTransformer(nn.Module):
    """
    Vision Transformer whose class-token output is concatenated with a vector of
    extra "forcing" features before the final linear head.

    Args:
        image_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        num_outputs: Width of the final linear layer's output.
        embeddings: Token embedding size.
        num_blocks: Number of stacked transformer blocks.
        num_heads: Attention heads per block.
        hidden_features: Hidden width of each block's MLP.
        qkv_b: Passed through to each block's attention layer.
        attention_dropout_p: Dropout on the attention weights.
        projection_dropout_p: Dropout after the attention projection.
        mlp_dropout_p: Dropout inside each block's MLP.
        pos_embedding_dropout: Dropout applied after adding positional embeddings.
        num_forcing_features: Length of the forcing vector appended to the class
            token before the output layer (default 7, the previously hard-coded value).

    Note:
        Each entry of ``transformer_blocks`` is an independent ``TransformerBlock``.
        Earlier versions put one shared block instance in the list ``num_blocks``
        times (so every layer had the same weights) and also registered it as
        ``transformer_block``; checkpoints saved from that version contain extra
        ``transformer_block.*`` keys.
    """
    def __init__(self, image_size=512, patch_size=16, in_channels=3, num_outputs=1000, embeddings=768,
                 num_blocks=12, num_heads=12, hidden_features=2048, qkv_b=True, attention_dropout_p=0,
                 projection_dropout_p=0, mlp_dropout_p=0, pos_embedding_dropout=0, num_forcing_features=7):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              in_channels=in_channels,
                                              embedding=embeddings)
        # Learnable class token and positional embedding (one position per patch
        # plus one for the class token), both initialised to zeros.
        self.class_token = nn.Parameter(torch.zeros(size=(1,1,embeddings)))
        self.positional_embedding = nn.Parameter(torch.zeros(size=(1,1+self.patch_embedding.num_patches, embeddings)))
        self.positional_dropout = nn.Dropout(pos_embedding_dropout)

        # Construct a NEW block on every iteration. Repeating a single instance in
        # the ModuleList would make all layers share one set of parameters.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embedding=embeddings,
                             num_heads=num_heads,
                             hidden_features=hidden_features,
                             qkv_b=qkv_b,
                             attention_dropout_p=attention_dropout_p,
                             projection_dropout_p=projection_dropout_p,
                             mlp_dropout_p=mlp_dropout_p)
            for _ in range(num_blocks)
        ])

        self.layernorm = nn.LayerNorm(embeddings, eps=1e-6)
        # The head sees the class-token embedding concatenated with the forcing vector.
        self.out = nn.Linear(embeddings + num_forcing_features, num_outputs)

    def forward(self, x, forcing):
        """Run the model on an image batch ``x`` and its per-sample ``forcing`` vectors.

        ``forcing`` is concatenated to the class-token output along dim 1, so it
        must be 2-D with ``num_forcing_features`` columns and one row per sample.
        """
        num_samples = x.shape[0]
        x = self.patch_embedding(x)

        # Prepend one class token per sample, then add positional information.
        class_token = self.class_token.expand(num_samples, -1, -1)
        x = torch.cat((class_token, x), dim=1)
        x = x + self.positional_embedding
        x = self.positional_dropout(x)

        for transformer_block in self.transformer_blocks:
            x = transformer_block(x)

        x = self.layernorm(x)
        # Only the class token (position 0) is used for the prediction.
        output_class_token = x[:, 0]
        output_class_token = torch.cat((output_class_token, forcing), dim=1)
        x = self.out(output_class_token)
        return x

In [13]:
# Model used for this experiment: 33x33 inputs with 5 channels, cut into 3x3
# patches, producing a single regression output.
vit = VisionTransformer(
    # input geometry
    image_size=33,
    patch_size=3,
    in_channels=5,
    # output head
    num_outputs=1,
    # transformer size
    embeddings=768,
    num_blocks=6,
    num_heads=12,
    hidden_features=1024,
    qkv_b=True,
    # regularisation
    attention_dropout_p=0.2,
    projection_dropout_p=0.2,
    mlp_dropout_p=0.2,
    pos_embedding_dropout=0.2,
)

In [None]:
from tqdm import tqdm
# Anomaly detection is a debugging aid that adds checks to every backward pass and
# slows training down; keep it off for real runs and flip this flag only while
# hunting NaN/Inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

EPOCHS = 50
LEARNING_RATE = 0.0005
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 256

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=6)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=6)

model = vit.to(DEVICE)
# Spread batches over GPUs 0 and 1 only when both exist; wrapping unconditionally
# fails on CPU-only or single-GPU machines.
if torch.cuda.device_count() >= 2:
    model = torch.nn.DataParallel(model, device_ids=[0, 1])
loss_fn = nn.SmoothL1Loss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.95)
# Two schedulers act on the same optimizer: a fixed per-epoch exponential decay
# plus a reduction whenever the test loss stops improving.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
test_loss = []
train_loss = []

lst_mean = lst_mean.to(DEVICE)
lst_std = lst_std.to(DEVICE)
forcing_mean = forcing_mean.to(DEVICE)
forcing_std = forcing_std.to(DEVICE)


def process_data(image, forcing, lst):
    """Move one batch to DEVICE and prepare it for the model.

    * ``image`` is cast to float32; channel 4 is clipped to [0, 600] and channels
      0-3 to [0, 1] before ``image_normalize`` is applied.
    * ``forcing`` is standardised with ``forcing_mean`` / ``forcing_std`` and cast
      to float32.
    * ``lst`` (the target) is NOT standardised; it is only reshaped to a column
      vector and cast to float32.

    Returns the processed ``(image, forcing, lst)`` tuple.
    """
    image = image.to(DEVICE).to(torch.float32)
    forcing = forcing.to(DEVICE)
    lst = lst.to(DEVICE)

    image[:, 4] = torch.clip(image[:, 4], min=0, max=600)
    image[:, :4] = torch.clip(image[:, :4], min=0, max=1)
    image = image_normalize(image)

    forcing = torch.div(torch.sub(forcing, forcing_mean), forcing_std).to(torch.float32)
    lst = lst.view(-1, 1).to(torch.float32)
    return image, forcing, lst


min_test_loss = np.inf
for epoch in range(EPOCHS):
    print("****** EPOCH: [{}/{}] LR: {} ******".format(epoch, EPOCHS, round(optimizer.param_groups[0]['lr'], 6)))
    running_train_loss = 0
    train_n_iter = 0
    running_test_loss = 0
    test_n_iter = 0

    # Training pass: dropout must be active.
    model.train()
    loop_train = tqdm(train_loader, total=(training_samples_len//BATCH_SIZE) + 1, leave=True)
    for image, forcing, lst in loop_train:
        image, forcing, lst = process_data(image, forcing, lst)
        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks also run.
        forward_out = model(image, forcing)
        loss = loss_fn(forward_out, lst)
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()
        train_n_iter += 1
        loop_train.set_postfix(train_loss=loss.item())

    # Evaluation pass: switch dropout off, otherwise the test loss is computed
    # on randomly perturbed activations and is noisier and higher than it should be.
    model.eval()
    loop_test = tqdm(test_loader, total=(testing_samples_len//BATCH_SIZE) + 1, leave=False)
    with torch.no_grad():
        for image, forcing, lst in loop_test:
            image, forcing, lst = process_data(image, forcing, lst)
            pred = model(image, forcing)
            testloss = loss_fn(pred, lst)
            running_test_loss += testloss.item()
            test_n_iter += 1
            loop_test.set_postfix(test_loss=testloss.item())

    avg_train_loss = running_train_loss/train_n_iter
    train_loss.append(avg_train_loss)
    avg_test_loss = running_test_loss/test_n_iter
    test_loss.append(avg_test_loss)

    scheduler.step()
    scheduler2.step(avg_test_loss)
    if avg_test_loss < min_test_loss:
        print("Saving Model")
        min_test_loss = avg_test_loss
        # Keeps the best-so-far weights. When the model is wrapped in DataParallel
        # the saved keys carry a "module." prefix.
        torch.save(model.state_dict(), "resnetmodel.pt")
    print("------ Train Loss: {}, Test Loss: {} ------".format(avg_train_loss, avg_test_loss))
            
        
        
        
        

****** EPOCH: [0/50] LR: 0.0005 ******


  3%|██▌                                                                                          | 187/6838 [00:54<29:27,  3.76it/s, train_loss=8.35]

In [None]:
import matplotlib.pyplot as plt
# Plot the per-epoch average losses recorded by the training loop.
# The x-axis is derived from the number of recorded epochs; a hard-coded range
# raises a shape-mismatch error whenever the run did not last exactly that long.
plt.plot(range(len(train_loss)), train_loss, label="train_loss")
plt.plot(range(len(test_loss)), test_loss, label="test_loss")
plt.xlabel("Epoch")
plt.ylabel("Average loss")
plt.legend()
plt.title("Training Error LST Prediction")
# Apply the axis limits BEFORE saving so the saved image matches what is displayed.
plt.ylim([0, 10])
plt.savefig("Training Curve.png")
plt.show()

In [None]:
# Load a previously saved object from PATH and bind it to `model`.
# NOTE(review): PATH is not defined in any cell visible in this notebook, so this
# raises NameError on a fresh kernel unless it is defined elsewhere — TODO confirm.
# NOTE(review): the training cell saves `model.state_dict()` to "resnetmodel.pt";
# if PATH refers to that file, torch.load returns the state dict (a mapping of
# tensors) rather than a module, and it would have to be passed to
# `load_state_dict` on a constructed VisionTransformer — TODO confirm intent.
model = torch.load(PATH)