In [4]:
import numpy as np
import webdataset as wds
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms 
import os
import random

# Directory that holds the WebDataset shard archives. The trailing slash is
# required: shard URLs are later built by plain string concatenation
# (path + "shard-{...}.tar").
PATH_TO_DATA = "/media/priyammazumdar/M.2 Drive/M2 Data Storage/sharded_data/" 

In [5]:
# Normalisation statistics, kept in one place so the torchvision transforms and
# the raw tensors below cannot drift apart.
# NOTE(review): presumably dataset-wide per-channel statistics; how they were
# computed is not shown in this notebook.
_IMAGE_MEAN = [0.17960437768666657, 0.14584139607643212, 0.10744440357398845, 0.2583671063835548]
_IMAGE_STD = [0.059635202669355195, 0.04059554002618016, 0.03371736326989986, 0.06295501902505744]

_FORCING_MEAN = [444.9605606256559, 991.7980623653417, 0.00039606951184754176, 96111.04161525163, 0.006652783216819315, 314.3219695851273, 2.82168247768119]
_FORCING_STD = [5.5216369223813535, 12.951212256256913, 0.0002824274832735609, 975.3770569179914, 0.00012386107613000674, 0.6004463118907452, 0.34279194598853185]

# torchvision transforms for the 4-value image input and the 7-value forcing vector.
image_normalize = transforms.Normalize(mean=_IMAGE_MEAN, std=_IMAGE_STD)
forcing_normalize = transforms.Normalize(mean=_FORCING_MEAN, std=_FORCING_STD)

# The same forcing statistics as float64 tensors, for manual (x - mean) / std.
forcing_mean = torch.from_numpy(np.array(_FORCING_MEAN))
forcing_std = torch.from_numpy(np.array(_FORCING_STD))

# Statistics of the LST regression target (single value each, float64).
lst_mean = torch.from_numpy(np.array([312.8291360088677]))
lst_std = torch.from_numpy(np.array([11.376636496297289]))

In [6]:
def create_train_test(path_to_data, train_perc, test_perc, samples_per_shard=10000):
    """Split the shard archives under ``path_to_data`` into train/test WebDatasets.

    Every shard except the last one (by sorted file name) is taken to contain
    exactly ``samples_per_shard`` samples; the last shard is iterated once to
    count its real size. The requested fractions of the total sample count are
    converted to a whole number of shards (rounded down), and that many shards
    are drawn at random, without overlap between the two splits.

    Args:
        path_to_data: Directory containing the shard files. It must end with a
            path separator because the shard URLs are built by string
            concatenation. Characters 6..11 of each file name are used as the
            shard id (e.g. ``shard-000123.tar``).
        train_perc: Fraction of all samples to cover with training shards.
        test_perc: Fraction of all samples to cover with test shards.
        samples_per_shard: Number of samples in every full shard.

    Returns:
        ``((train_dataset, n_train_samples), (test_dataset, n_test_samples))``
        where the datasets are ``wds.WebDataset`` pipelines yielding
        ``(image, forcing, lst)`` tuples and the counts are the number of
        samples contained in the selected shards.

    Raises:
        ValueError: Propagated from ``random.sample`` when more shards are
            requested than are available.
    """
    files = []
    for _dirpath, _dirnames, filenames in os.walk(path_to_data):
        files.extend(filenames)
    # os.walk returns names in arbitrary (filesystem) order. Sort so that the
    # highest-numbered shard -- the only one that may be partially filled --
    # really is the last element.
    files.sort()

    saturated = set(files[:-1])
    unsaturated = files[-1]

    # Count the samples of the (possibly partial) last shard by iterating it once.
    counter = 0
    for _sample in wds.WebDataset(os.path.join(path_to_data, unsaturated)):
        counter += 1

    # Translate the requested fractions into a whole number of shards.
    total_samples = counter + len(saturated) * samples_per_shard
    n_train_shards = int(total_samples * train_perc // samples_per_shard)
    n_test_shards = int(total_samples * test_perc // samples_per_shard)

    train_files = random.sample(files, n_train_shards)
    train_file_set = set(train_files)
    remaining_files = [name for name in files if name not in train_file_set]
    test_files = random.sample(remaining_files, n_test_shards)

    def _count_samples(shard_names):
        # Full shards contribute samples_per_shard, the last shard its real count.
        n_samples = 0
        for name in shard_names:
            if name in saturated:
                n_samples += samples_per_shard
            elif name == unsaturated:
                n_samples += counter
        return n_samples

    training_samples = _count_samples(train_files)
    testing_samples = _count_samples(test_files)

    def _brace_url(shard_names):
        # Build a brace-expansion URL such as ".../shard-{000001,000007}.tar".
        shard_ids = [name[6:12] for name in shard_names]
        return path_to_data + "shard-{" + ",".join(shard_ids) + "}.tar"

    train_data = (
        wds.WebDataset(_brace_url(train_files))
        .shuffle(30000, initial=30000)
        .decode("rgb")
        .rename(image="image.pyd", forcing="forcing.pyd", lst="lst.pyd")
        .to_tuple("image", "forcing", "lst")
    )
    test_data = (
        wds.WebDataset(_brace_url(test_files))
        .decode("rgb")
        .shuffle(30000, initial=30000)
        .rename(image="image.pyd", forcing="forcing.pyd", lst="lst.pyd")
        .to_tuple("image", "forcing", "lst")
    )

    return (train_data, training_samples), (test_data, testing_samples)
    
# Build the datasets used by the rest of the notebook: shards covering 25% of
# all samples for training and 10% for testing, plus the sample count of each split.
(train_data, training_samples_len), (test_data, testing_samples_len) = create_train_test(PATH_TO_DATA, 0.25, 0.1)


In [7]:
class LSTModel(nn.Module):
    """CNN regressor that combines an image patch with a forcing vector.

    Three Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d stages process the image;
    the flattened feature map is concatenated with the forcing vector and fed
    through three fully connected layers that output a single value per sample.

    Args:
        input_shape: ``(channels, height, width)`` of one image.
        forcing_shape: Shape of a forcing batch holding one sample,
            ``(1, n_forcing)``; only used to size the first linear layer.
    """

    def __init__(self, input_shape=(4,33,33), forcing_shape=(1,7)):
        super().__init__()
        self.in_channels = input_shape[0]
        self.input_shape = input_shape

        self.conv1 = nn.Conv2d(in_channels=self.in_channels, out_channels=8, kernel_size=(2,2))
        self.conv1_bn = nn.BatchNorm2d(8)
        self.mp1 = nn.MaxPool2d(kernel_size=(2,2), stride=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(2,2))
        self.conv2_bn = nn.BatchNorm2d(16)
        self.mp2 = nn.MaxPool2d(kernel_size=(2, 2), stride=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2, 2))
        self.conv3_bn = nn.BatchNorm2d(32)
        self.mp3 = nn.MaxPool2d(kernel_size=(2, 2), stride=1)

        # Discover the width of the flattened feature vector by pushing one
        # dummy sample through the convolutional stack.
        self.flatten_shape = None
        zero_ex = torch.zeros(input_shape).unsqueeze(0)
        zero_forcing = torch.zeros(forcing_shape)

        # The probe must run in eval mode: in training mode BatchNorm updates
        # its running statistics on every forward pass (torch.no_grad() does
        # not prevent that), which would bake an all-zero batch into the
        # freshly initialised model.
        self.eval()
        try:
            with torch.no_grad():
                self.convolutions(zero_ex, zero_forcing)
        finally:
            self.train()

        self.fc1 = nn.Linear(in_features=self.flatten_shape, out_features=512)
        self.drop1 = nn.Dropout(0.6)
        self.fc2 = nn.Linear(in_features=512, out_features=128)
        self.drop2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(in_features=128, out_features=1)

    def convolutions(self, x, forcing):
        """Run the convolutional stack and append the forcing vector.

        Args:
            x: Image batch of shape ``(batch, channels, height, width)``.
            forcing: Forcing batch of shape ``(batch, n_forcing)``.

        Returns:
            Tensor of shape ``(batch, flatten_shape)``: the flattened feature
            maps followed by the forcing values (with the default shapes this
            is 32 * 27 * 27 + 7 = 23335 columns).
        """
        x = F.relu(self.conv1_bn(self.conv1(x)))
        x = self.mp1(x)
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = self.mp2(x)
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = self.mp3(x)

        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.shape[0], -1)
        x = torch.cat((x, forcing), dim=1)

        # Recorded once (during construction) to size the first linear layer.
        if self.flatten_shape is None:
            self.flatten_shape = x.shape[1]

        return x

    def forward(self, x, forcing):
        """Predict one value per sample; returns a ``(batch, 1)`` tensor."""
        x = self.convolutions(x, forcing)
        x = F.relu(self.fc1(x))
        x = self.drop1(x)
        x = F.relu(self.fc2(x))
        x = self.drop2(x)
        x = self.fc3(x)

        return x
        

In [None]:
from tqdm import tqdm

# Training hyper-parameters.
EPOCHS = 250
LEARNING_RATE = 5e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 512

# The datasets are iterable WebDataset pipelines, so the loaders only batch them.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=6)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=6)

model = LSTModel().to(DEVICE)
# Replicate over at most two GPUs (the original configuration), but never name
# a device that does not exist: a hard-coded [0, 1] breaks on a single-GPU
# machine even though DEVICE itself falls back gracefully.
gpu_ids = list(range(min(2, torch.cuda.device_count())))
model = torch.nn.DataParallel(model, device_ids=gpu_ids)
loss_fn = nn.L1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Per-epoch average losses, appended to by the training loop.
test_loss = []
train_loss = []

# Move the normalisation statistics next to the data they are applied to.
lst_mean = lst_mean.to(DEVICE)
lst_std = lst_std.to(DEVICE)
forcing_mean = forcing_mean.to(DEVICE)
forcing_std = forcing_std.to(DEVICE)

def process_data(image, forcing, lst):
    """Move one batch to DEVICE and prepare it for the model.

    The image is cast to float32, clamped to be non-negative and normalised
    with ``image_normalize``; the forcing vector is standardised with
    ``forcing_mean`` / ``forcing_std``; the target is only reshaped to a
    ``(batch, 1)`` float32 column.

    Returns:
        The processed ``(image, forcing, lst)`` tensors.
    """
    image = image.to(torch.float32).to(DEVICE)
    forcing = forcing.to(DEVICE)
    lst = lst.to(DEVICE)

    # Image: drop negative values, then normalise.
    image = image_normalize(torch.clip(image, min=0))

    # Forcing: standardise, then cast down to float32 for the model.
    forcing = ((forcing - forcing_mean) / forcing_std).to(torch.float32)

    # Target: left in its raw scale (standardising it with lst_mean / lst_std
    # is deliberately not applied here); only reshaped into a column.
    lst = lst.view(-1, 1).to(torch.float32)

    return image, forcing, lst

    
# Main loop: one optimisation pass over the training shards followed by one
# evaluation pass over the test shards per epoch; the per-epoch mean batch
# losses are appended to train_loss / test_loss.
for epoch in range(EPOCHS):
    print("****** EPOCH: [{}/{}] LR: {} ******".format(epoch, EPOCHS, round(optimizer.param_groups[0]['lr'], 6)))
    running_train_loss = 0
    train_n_iter = 0
    running_test_loss = 0
    test_n_iter = 0

    # ---- training pass ----
    # Dropout is active and BatchNorm uses batch statistics here, so the
    # training loss is not directly comparable with the evaluation loss below.
    model.train()
    loop_train = tqdm(train_loader, total=(training_samples_len//BATCH_SIZE) + 1, leave=True)
    for image, forcing, lst in loop_train:
        image, forcing, lst = process_data(image, forcing, lst)
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so that the regular
        # nn.Module call path (including any registered hooks) is used.
        forward_out = model(image, forcing)
        loss = loss_fn(forward_out, lst)
        loss.backward()
        optimizer.step()
        # .item() synchronises with the device; fetch the scalar only once.
        batch_train_loss = loss.item()
        running_train_loss += batch_train_loss
        train_n_iter += 1
        loop_train.set_postfix(train_loss=batch_train_loss)

    # ---- evaluation pass ----
    loop_test = tqdm(test_loader, total=(testing_samples_len//BATCH_SIZE) + 1, leave=False)

    model.eval()
    with torch.no_grad():
        for image, forcing, lst in loop_test:
            image, forcing, lst = process_data(image, forcing, lst)
            pred = model(image, forcing)
            testloss = loss_fn(pred, lst)
            batch_test_loss = testloss.item()
            running_test_loss += batch_test_loss
            test_n_iter += 1
            loop_test.set_postfix(test_loss=batch_test_loss)

    # Mean of the per-batch losses (batches are not weighted by their size).
    avg_train_loss = running_train_loss/train_n_iter
    train_loss.append(avg_train_loss)
    avg_test_loss = running_test_loss/test_n_iter
    test_loss.append(avg_test_loss)

    print("------ Train Loss: {}, Test Loss: {} ------".format(avg_train_loss, avg_test_loss))
            
        
        
        
        

****** EPOCH: [0/250] LR: 5e-05 ******


999it [01:12, 13.85it/s, train_loss=24.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 45.318351279746544, Test Loss: 9.322657047794555 ------
****** EPOCH: [1/250] LR: 5e-05 ******


999it [01:05, 15.34it/s, train_loss=25.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 26.87176841491455, Test Loss: 8.278918167056165 ------
****** EPOCH: [2/250] LR: 5e-05 ******


999it [01:18, 12.76it/s, train_loss=24.9]                                                                                                             
                                                                                                                                                      

------ Train Loss: 25.871966412594844, Test Loss: 8.035880145687742 ------
****** EPOCH: [3/250] LR: 5e-05 ******


999it [01:14, 13.33it/s, train_loss=24.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 25.50633617779156, Test Loss: 8.504171430762044 ------
****** EPOCH: [4/250] LR: 5e-05 ******


999it [01:03, 15.61it/s, train_loss=25.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 25.26722044080824, Test Loss: 7.777043587060144 ------
****** EPOCH: [5/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=25.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 25.11783334466669, Test Loss: 7.91010191355865 ------
****** EPOCH: [6/250] LR: 5e-05 ******


999it [01:05, 15.23it/s, train_loss=25]                                                                                                               
                                                                                                                                                      

------ Train Loss: 25.061051528136414, Test Loss: 7.695773169473949 ------
****** EPOCH: [7/250] LR: 5e-05 ******


999it [01:04, 15.40it/s, train_loss=24.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.952385450865293, Test Loss: 7.500165901813411 ------
****** EPOCH: [8/250] LR: 5e-05 ******


999it [01:05, 15.28it/s, train_loss=24.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.90837676054961, Test Loss: 7.306423337326437 ------
****** EPOCH: [9/250] LR: 5e-05 ******


999it [01:05, 15.30it/s, train_loss=22.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.813305013769263, Test Loss: 7.803714809078856 ------
****** EPOCH: [10/250] LR: 5e-05 ******


999it [01:05, 15.36it/s, train_loss=27.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.793039839308303, Test Loss: 7.860170239724484 ------
****** EPOCH: [11/250] LR: 5e-05 ******


999it [01:05, 15.29it/s, train_loss=23.9]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.82188403499019, Test Loss: 7.419244789230037 ------
****** EPOCH: [12/250] LR: 5e-05 ******


999it [01:06, 15.03it/s, train_loss=25.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.744682594581885, Test Loss: 7.119758206575655 ------
****** EPOCH: [13/250] LR: 5e-05 ******


999it [01:05, 15.26it/s, train_loss=24]                                                                                                               
                                                                                                                                                      

------ Train Loss: 24.730883870396887, Test Loss: 8.081148583271782 ------
****** EPOCH: [14/250] LR: 5e-05 ******


999it [01:05, 15.15it/s, train_loss=23.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.7168826140441, Test Loss: 7.456955903677771 ------
****** EPOCH: [15/250] LR: 5e-05 ******


999it [01:05, 15.31it/s, train_loss=22.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.665567648184073, Test Loss: 7.222770358100155 ------
****** EPOCH: [16/250] LR: 5e-05 ******


999it [01:05, 15.32it/s, train_loss=23.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.620338306293355, Test Loss: 7.594702420500934 ------
****** EPOCH: [17/250] LR: 5e-05 ******


999it [01:05, 15.21it/s, train_loss=25.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.65278163256946, Test Loss: 7.059181609129543 ------
****** EPOCH: [18/250] LR: 5e-05 ******


999it [01:05, 15.19it/s, train_loss=24.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.65312418207392, Test Loss: 7.283801253071896 ------
****** EPOCH: [19/250] LR: 5e-05 ******


999it [01:05, 15.30it/s, train_loss=23.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.58201300131308, Test Loss: 7.105882824980063 ------
****** EPOCH: [20/250] LR: 5e-05 ******


999it [01:05, 15.31it/s, train_loss=24.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.589734194872975, Test Loss: 7.173270115392462 ------
****** EPOCH: [21/250] LR: 5e-05 ******


999it [01:05, 15.19it/s, train_loss=24.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.599985517897046, Test Loss: 7.612716425493889 ------
****** EPOCH: [22/250] LR: 5e-05 ******


999it [01:05, 15.36it/s, train_loss=23.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.57329170792191, Test Loss: 7.637055179189304 ------
****** EPOCH: [23/250] LR: 5e-05 ******


999it [01:04, 15.48it/s, train_loss=23.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.557777332233357, Test Loss: 7.629064866128912 ------
****** EPOCH: [24/250] LR: 5e-05 ******


999it [01:05, 15.18it/s, train_loss=24.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.5083162061445, Test Loss: 7.305119525357552 ------
****** EPOCH: [25/250] LR: 5e-05 ******


999it [01:04, 15.49it/s, train_loss=24.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.514933506886404, Test Loss: 7.618804053001598 ------
****** EPOCH: [26/250] LR: 5e-05 ******


999it [01:04, 15.46it/s, train_loss=24.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.502304016052186, Test Loss: 7.12027974903281 ------
****** EPOCH: [27/250] LR: 5e-05 ******


999it [01:05, 15.28it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.54121092298964, Test Loss: 7.317943282538864 ------
****** EPOCH: [28/250] LR: 5e-05 ******


999it [01:05, 15.35it/s, train_loss=24.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.523877538121617, Test Loss: 7.099380520999734 ------
****** EPOCH: [29/250] LR: 5e-05 ******


999it [01:05, 15.31it/s, train_loss=24.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.467887601575576, Test Loss: 7.2886177743146865 ------
****** EPOCH: [30/250] LR: 5e-05 ******


999it [01:04, 15.50it/s, train_loss=23]                                                                                                               
                                                                                                                                                      

------ Train Loss: 24.4431600942984, Test Loss: 6.831168877896924 ------
****** EPOCH: [31/250] LR: 5e-05 ******


999it [01:05, 15.32it/s, train_loss=23.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.463671152536815, Test Loss: 7.156536563398874 ------
****** EPOCH: [32/250] LR: 5e-05 ******


999it [01:04, 15.38it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.403725161089433, Test Loss: 7.45512178585614 ------
****** EPOCH: [33/250] LR: 5e-05 ******


999it [01:05, 15.36it/s, train_loss=23.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.347578860140658, Test Loss: 7.410221999066735 ------
****** EPOCH: [34/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=24.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.423823712705016, Test Loss: 7.127873546580978 ------
****** EPOCH: [35/250] LR: 5e-05 ******


999it [01:05, 15.31it/s, train_loss=25.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.387671901179743, Test Loss: 7.160876018746856 ------
****** EPOCH: [36/250] LR: 5e-05 ******


999it [01:04, 15.50it/s, train_loss=24.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.431199941549217, Test Loss: 6.989578623457003 ------
****** EPOCH: [37/250] LR: 5e-05 ******


999it [01:04, 15.38it/s, train_loss=24.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.37936162137174, Test Loss: 7.0619898181276275 ------
****** EPOCH: [38/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=23.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.41398406004882, Test Loss: 6.991316755410984 ------
****** EPOCH: [39/250] LR: 5e-05 ******


999it [01:05, 15.20it/s, train_loss=24.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.36984130928108, Test Loss: 6.844746341560093 ------
****** EPOCH: [40/250] LR: 5e-05 ******


999it [01:04, 15.45it/s, train_loss=24.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.33564909704932, Test Loss: 7.491229782249722 ------
****** EPOCH: [41/250] LR: 5e-05 ******


999it [01:04, 15.38it/s, train_loss=23.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.373037737291735, Test Loss: 7.485680209803702 ------
****** EPOCH: [42/250] LR: 5e-05 ******


999it [01:04, 15.45it/s, train_loss=24.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.35137935825535, Test Loss: 7.112134951625379 ------
****** EPOCH: [43/250] LR: 5e-05 ******


999it [01:04, 15.47it/s, train_loss=24.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.30520830545817, Test Loss: 6.9116015591597195 ------
****** EPOCH: [44/250] LR: 5e-05 ******


999it [01:04, 15.47it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.339457902345096, Test Loss: 7.093646313333269 ------
****** EPOCH: [45/250] LR: 5e-05 ******


999it [01:05, 15.25it/s, train_loss=23.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.354515814566398, Test Loss: 7.970499216602539 ------
****** EPOCH: [46/250] LR: 5e-05 ******


999it [01:05, 15.33it/s, train_loss=23.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.31557130097627, Test Loss: 7.565077981367934 ------
****** EPOCH: [47/250] LR: 5e-05 ******


999it [01:05, 15.35it/s, train_loss=24.9]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.307252780811208, Test Loss: 7.100840380954258 ------
****** EPOCH: [48/250] LR: 5e-05 ******


999it [01:04, 15.54it/s, train_loss=22.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.26309147587529, Test Loss: 7.547475736153308 ------
****** EPOCH: [49/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=25.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.28443506028917, Test Loss: 7.47606847976065 ------
****** EPOCH: [50/250] LR: 5e-05 ******


999it [01:05, 15.30it/s, train_loss=22.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.23119276350325, Test Loss: 6.880649803858723 ------
****** EPOCH: [51/250] LR: 5e-05 ******


999it [01:04, 15.50it/s, train_loss=22.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.268764698230946, Test Loss: 7.056071477493054 ------
****** EPOCH: [52/250] LR: 5e-05 ******


999it [01:03, 15.67it/s, train_loss=21.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.260763072872066, Test Loss: 7.1903189814030215 ------
****** EPOCH: [53/250] LR: 5e-05 ******


999it [01:05, 15.14it/s, train_loss=23.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.235320682162875, Test Loss: 7.08977087379107 ------
****** EPOCH: [54/250] LR: 5e-05 ******


999it [01:04, 15.38it/s, train_loss=25.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.240390558023233, Test Loss: 7.1777723803738045 ------
****** EPOCH: [55/250] LR: 5e-05 ******


999it [01:04, 15.44it/s, train_loss=24.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.227001012624562, Test Loss: 7.252327430066724 ------
****** EPOCH: [56/250] LR: 5e-05 ******


999it [01:04, 15.45it/s, train_loss=24.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.212619172440874, Test Loss: 7.190101911573846 ------
****** EPOCH: [57/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=25.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.256790227956838, Test Loss: 7.013001927264451 ------
****** EPOCH: [58/250] LR: 5e-05 ******


999it [01:05, 15.27it/s, train_loss=23.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.204195160049576, Test Loss: 6.901262170772262 ------
****** EPOCH: [59/250] LR: 5e-05 ******


999it [01:05, 15.29it/s, train_loss=24]                                                                                                               
                                                                                                                                                      

------ Train Loss: 24.196043652218503, Test Loss: 6.818858730006339 ------
****** EPOCH: [60/250] LR: 5e-05 ******


999it [01:05, 15.36it/s, train_loss=23.9]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.20047527844006, Test Loss: 7.177088722964834 ------
****** EPOCH: [61/250] LR: 5e-05 ******


999it [01:05, 15.34it/s, train_loss=24.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.155269023295755, Test Loss: 7.851035492069225 ------
****** EPOCH: [62/250] LR: 5e-05 ******


999it [01:05, 15.32it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.17573650654133, Test Loss: 7.644525342786372 ------
****** EPOCH: [63/250] LR: 5e-05 ******


999it [01:04, 15.39it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.144348541656893, Test Loss: 7.309517618363278 ------
****** EPOCH: [64/250] LR: 5e-05 ******


999it [01:05, 15.21it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.157809147724997, Test Loss: 6.884965675131318 ------
****** EPOCH: [65/250] LR: 5e-05 ******


999it [01:05, 15.26it/s, train_loss=23.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.138087622992867, Test Loss: 7.2274885177612305 ------
****** EPOCH: [66/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=24.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.13814081873622, Test Loss: 6.842704181138634 ------
****** EPOCH: [67/250] LR: 5e-05 ******


999it [01:04, 15.43it/s, train_loss=24.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.14779286007504, Test Loss: 6.953224936112535 ------
****** EPOCH: [68/250] LR: 5e-05 ******


999it [01:05, 15.32it/s, train_loss=23.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.10078826919571, Test Loss: 6.8414300710416684 ------
****** EPOCH: [69/250] LR: 5e-05 ******


999it [01:05, 15.19it/s, train_loss=24.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.130361604738283, Test Loss: 7.287650260828474 ------
****** EPOCH: [70/250] LR: 5e-05 ******


999it [01:04, 15.38it/s, train_loss=23.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.09889544428767, Test Loss: 6.885379646030175 ------
****** EPOCH: [71/250] LR: 5e-05 ******


999it [01:04, 15.48it/s, train_loss=23.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.128462635838353, Test Loss: 6.932541446637381 ------
****** EPOCH: [72/250] LR: 5e-05 ******


999it [01:04, 15.39it/s, train_loss=23.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.095558649546152, Test Loss: 7.309704580887925 ------
****** EPOCH: [73/250] LR: 5e-05 ******


999it [01:05, 15.26it/s, train_loss=24.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.076239222162837, Test Loss: 6.860566206995001 ------
****** EPOCH: [74/250] LR: 5e-05 ******


999it [01:05, 15.35it/s, train_loss=23.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.09683123866359, Test Loss: 7.288457341605636 ------
****** EPOCH: [75/250] LR: 5e-05 ******


999it [01:05, 15.34it/s, train_loss=23.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.075982123404533, Test Loss: 7.018926120651555 ------
****** EPOCH: [76/250] LR: 5e-05 ******


999it [01:04, 15.43it/s, train_loss=24.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.072885106633734, Test Loss: 6.943648993061279 ------
****** EPOCH: [77/250] LR: 5e-05 ******


999it [01:05, 15.29it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.01781629155706, Test Loss: 7.492800915906877 ------
****** EPOCH: [78/250] LR: 5e-05 ******


999it [01:05, 15.29it/s, train_loss=24.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 24.012928740278976, Test Loss: 7.239427052173518 ------
****** EPOCH: [79/250] LR: 5e-05 ******


999it [01:04, 15.48it/s, train_loss=23]                                                                                                               
                                                                                                                                                      

------ Train Loss: 24.00019879193158, Test Loss: 6.905796345115313 ------
****** EPOCH: [80/250] LR: 5e-05 ******


999it [01:04, 15.45it/s, train_loss=21.9]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.983308774930936, Test Loss: 7.163749579850792 ------
****** EPOCH: [81/250] LR: 5e-05 ******


999it [01:04, 15.42it/s, train_loss=23.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.95176760737483, Test Loss: 6.907194947228214 ------
****** EPOCH: [82/250] LR: 5e-05 ******


999it [01:05, 15.32it/s, train_loss=23.7]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.994441434308452, Test Loss: 6.942430740685633 ------
****** EPOCH: [83/250] LR: 5e-05 ******


999it [01:05, 15.30it/s, train_loss=24.4]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.966710977487498, Test Loss: 6.88993692156022 ------
****** EPOCH: [84/250] LR: 5e-05 ******


999it [01:05, 15.31it/s, train_loss=23]                                                                                                               
                                                                                                                                                      

------ Train Loss: 24.03480780661643, Test Loss: 7.388748581639401 ------
****** EPOCH: [85/250] LR: 5e-05 ******


999it [01:04, 15.38it/s, train_loss=24.6]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.960400950801265, Test Loss: 7.443277069759853 ------
****** EPOCH: [86/250] LR: 5e-05 ******


999it [01:05, 15.35it/s, train_loss=24.2]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.951122070098663, Test Loss: 6.8955787728885705 ------
****** EPOCH: [87/250] LR: 5e-05 ******


999it [01:04, 15.42it/s, train_loss=24.1]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.993955195964396, Test Loss: 7.077628481811678 ------
****** EPOCH: [88/250] LR: 5e-05 ******


999it [01:05, 15.31it/s, train_loss=22.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.935350532646293, Test Loss: 7.0564809208594 ------
****** EPOCH: [89/250] LR: 5e-05 ******


999it [01:05, 15.28it/s, train_loss=24.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.951860187289952, Test Loss: 7.036798828144364 ------
****** EPOCH: [90/250] LR: 5e-05 ******


999it [01:04, 15.44it/s, train_loss=23.5]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.88123230270676, Test Loss: 7.0346660117812565 ------
****** EPOCH: [91/250] LR: 5e-05 ******


999it [01:05, 15.27it/s, train_loss=23.3]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.93307379798011, Test Loss: 6.91552633803508 ------
****** EPOCH: [92/250] LR: 5e-05 ******


999it [01:04, 15.41it/s, train_loss=22.8]                                                                                                             
                                                                                                                                                      

------ Train Loss: 23.896517880566723, Test Loss: 6.794576525083048 ------
****** EPOCH: [93/250] LR: 5e-05 ******


999it [01:05, 15.33it/s, train_loss=23]                                                                                                               
                                                                                                                                                      

------ Train Loss: 23.930039847815955, Test Loss: 6.766834444200932 ------
****** EPOCH: [94/250] LR: 5e-05 ******


 47%|████████████████████████████████████████████▎                                                 | 471/998 [00:41<00:37, 14.01it/s, train_loss=22.8]