In [None]:
import h5py
import numpy as np
import torch.utils.data as data
import torch.nn as nn
import torch
import os
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchvision.transforms import v2

In [None]:
class MyRandomRotation:
    """Callable transform: rotate the input by an angle picked uniformly from 0/90/180/270 degrees.

    The angle comes from the global NumPy RNG, so ``np.random.seed`` makes the
    choice reproducible.
    """

    def __call__(self, x):
        right_angles = [0, 90, 180, 270]
        picked = np.random.choice(right_angles)
        return v2.functional.rotate(x, picked)

train_transform = v2.Compose([v2.RandomHorizontalFlip(p=0.5), MyRandomRotation()])


def _fixed_rotation(angle):
    """Deterministic rotation by ``angle`` degrees, expressed as a degenerate RandomRotation range."""
    return v2.RandomRotation((angle, angle))


# Test-time augmentation: the 8 symmetries of the square -- rotations by 0/90/180/270
# degrees, then a horizontal flip alone and flip-followed-by-rotation for 90/180/270.
TEST_TRANSFORM_LIST = (
    [_fixed_rotation(angle) for angle in (0, 90, 180, 270)]
    + [v2.RandomHorizontalFlip(p=1)]
    + [v2.Compose([v2.RandomHorizontalFlip(p=1), _fixed_rotation(angle)]) for angle in (90, 180, 270)]
)
# Index-aligned inverses of TEST_TRANSFORM_LIST: undo the rotation first (rotate by
# 360 - angle), then undo the flip (a horizontal flip is its own inverse).
TEST_TRANSFORM_LIST_REVERSE = (
    [_fixed_rotation((360 - angle) % 360) for angle in (0, 90, 180, 270)]
    + [v2.RandomHorizontalFlip(p=1)]
    + [v2.Compose([_fixed_rotation(360 - angle), v2.RandomHorizontalFlip(p=1)]) for angle in (90, 180, 270)]
)

In [None]:
class RadarDataset(data.Dataset):
    """In-memory dataset of radar frame windows loaded from HDF5 files.

    Every top-level group of each HDF5 file is a timestamp (its name must parse
    as an int) holding ``intensity``, ``events`` and a ``reflectivity`` stack, of
    which levels 1 and 3 are used. Each frame is cleaned and padded by
    ``prepare_array`` and cached as a float16 tensor keyed by ``int(timestamp)``.

    A sample is a window of ``in_seq_len + out_seq_len`` timestamps that are
    exactly 600 apart (units presumably seconds -- TODO confirm):

    * ``x``: for each input timestamp the 4 channels
      [intensity, events, reflectivity1, reflectivity3], stacked along dim 0
      -> float32 tensor of shape (4 * in_seq_len, H, W);
    * ``y``: intensity of the output timestamps -> float32 tensor of shape
      (out_seq_len, H, W); an empty (0, H, W) tensor when out_seq_len == 0.

    Args:
        folder: directory that contains ``files``.
        files: HDF5 file names to load; when several files contain the same
            timestamp, the file loaded last wins.
        in_seq_len: number of input timestamps per sample.
        out_seq_len: number of target timestamps per sample.
        mode: ``'overlap'`` for stride-1 windows or ``'sequentially'`` for
            consecutive non-overlapping chunks; windows containing a time gap
            are dropped in both modes.
        with_time: if True a sample is ``((x, t_last), y)`` where ``t_last`` is
            the last timestamp of the window, otherwise ``(x, y)``.
        full_res: if False every padded frame is resized to 128x128.
        transform: optional callable applied once to ``x`` and ``y`` stacked
            along dim 0, so a random spatial augmentation hits both identically.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    def __init__(self, folder,files,in_seq_len=4, out_seq_len=12, mode='overlap', with_time=False, full_res=True,transform=None):
        self.in_seq_len = in_seq_len
        self.out_seq_len = out_seq_len
        self.seq_len = in_seq_len + out_seq_len
        self.with_time = with_time
        # One dict per channel: int timestamp -> float16 frame tensor.
        self.events = dict()
        self.intensity = dict()
        self.reflectivity1 = dict()
        self.reflectivity3 = dict()
        self.channels_per_timestamp=4
        self.full_res=full_res
        self.transform = transform
        self.__prepare_dicts(folder,files)
        self.__prepare_sequences(mode)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        window = self.sequences[index]
        n_in = self.in_seq_len * self.channels_per_timestamp

        x = []
        for timestamp in window[:self.in_seq_len]:
            # Keep this per-timestamp channel order stable: trained models depend on it.
            x.extend((
                self.intensity[timestamp],
                self.events[timestamp],
                self.reflectivity1[timestamp],
                self.reflectivity3[timestamp],
            ))
        x = torch.stack(x).type(torch.float32)

        targets = [self.intensity[timestamp] for timestamp in window[self.in_seq_len:]]
        if targets:
            y = torch.stack(targets).type(torch.float32)
        else:
            # No targets (e.g. inference with out_seq_len=0). Keep y 3-D with x's
            # height AND width (non-square frames included) so it can still be
            # vstacked onto x for the joint transform below.
            y = torch.empty((0, *x.shape[-2:]), dtype=torch.float32)

        if self.transform is not None:
            # Transform inputs and targets together so they stay spatially aligned.
            stacked = self.transform(torch.vstack((x, y)))
            x, y = stacked[:n_in], stacked[n_in:]

        if self.with_time:
            return (x, window[-1]), y
        return x, y

    def prepare_array(self,temp,full_res):
        """Clean one raw 2-D frame and return it as a padded float16 tensor.

        ``temp`` (a float NumPy array) is modified in place: the -1e6 sentinel
        becomes 0; the -2e6 sentinel, the first row and the last column become
        -0.001 (negative, so the training loss masks them out). The tensor is
        then padded by 2 pixels on every side with -0.001 and, unless
        ``full_res``, resized to 128x128.
        """
        temp[temp == -1e6] = 0
        temp[temp == -2e6] = -0.001
        temp[0,:]=-0.001
        temp[:,-1]=-0.001
        temp=torch.tensor(temp.astype(np.float16))
        temp=F.pad(temp,(2,2,2,2),value=-0.001)
        if(not full_res):
            temp = v2.functional.resize(temp.unsqueeze(0),(128,128),antialias=False).squeeze(0)
        return temp

    def __prepare_dicts(self, folder, files):
        """Load and preprocess every timestamp of every file into the per-channel dicts."""
        for file in files:
            with h5py.File(os.path.join(folder, file), mode='r') as d:
                for timestamp in tqdm(d.keys()):
                    group = d[timestamp]
                    key = int(timestamp)
                    self.intensity[key] = self.__load_frame(group['intensity'])
                    self.events[key] = self.__load_frame(group['events'])
                    self.reflectivity1[key] = self.__load_frame(group['reflectivity'][1])
                    self.reflectivity3[key] = self.__load_frame(group['reflectivity'][3])

    def __load_frame(self, raw):
        """Read ``raw`` (HDF5 dataset or array) as float32 and pass it through ``prepare_array``."""
        return self.prepare_array(np.array(raw, dtype=np.float32), self.full_res)

    def __prepare_sequences(self, mode):
        """Build ``self.sequences``: windows of ``seq_len`` sorted timestamps with no gaps."""
        step = 600  # required spacing between consecutive timestamps of a window
        timestamps = np.unique(list(self.intensity.keys()))  # np.unique returns sorted values
        if mode == 'sequentially':
            windows = [
                timestamps[index * self.seq_len: (index + 1) * self.seq_len]
                for index in range(len(timestamps) // self.seq_len)
            ]
        elif mode == 'overlap':
            windows = [
                timestamps[index: index + self.seq_len]
                for index in range(len(timestamps) - self.seq_len + 1)
            ]
        else:
            raise ValueError(f'Unknown mode {mode}')
        # Sorted unique timestamps span exactly (seq_len - 1) * step only when no frame is missing.
        self.sequences = [
            window for window in windows
            if int(window[-1]) - int(window[0]) == (self.seq_len - 1) * step
        ]

In [None]:
# Public test split: non-overlapping windows of the default 4 input frames and no
# targets; with_time=True so each sample carries the timestamp used to name predictions.
folder = "/kaggle/input/yandex-cup-ml-23-nowcasting/ML Cup 2023 Weather"
files = ['2022-test-public.hdf5']
test_dataset = RadarDataset(
    folder,
    files,
    out_seq_len=0,
    mode='sequentially',
    with_time=True,
)

In [None]:
# Training split: every file in the train folder, stride-1 windows, frames
# downsampled to 128x128 (full_res=False) and randomly flipped/rotated.
folder = "/kaggle/input/yandex-cup-ml-23-nowcasting/ML Cup 2023 Weather/train"
# Sort for a reproducible load order: os.listdir order depends on the filesystem,
# and when two files share a timestamp the file loaded last wins in the dataset dicts.
files = sorted(os.listdir(folder))
train_dataset = RadarDataset(folder, files, mode='overlap', full_res=False, transform=train_transform)

In [None]:
class Block(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by LeakyReLU(0.1).

    Maps (N, in_ch, H, W) -> (N, out_ch, H, W).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for conv_in in (in_ch, out_ch):
            layers += [nn.Conv2d(conv_in, out_ch, 3, padding=1), nn.LeakyReLU(0.1)]
        # The attribute name `block` and this layer order fix the state_dict keys
        # (block.0.*, block.2.*); keep them so saved checkpoints still load.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class Encoder(nn.Module):
    """UNet contracting path.

    Applies ``Block(chs[i], chs[i+1])`` for each consecutive channel pair and
    records every block output as a skip feature; after each block the map is
    max-pooled by 2 and passed through dropout before the next block.

    Args:
        chs: channel sizes; ``chs[0]`` is the input channel count.
        dropout: dropout probability after each pooling (default 0.25).

    ``forward`` returns the list of block outputs, highest resolution first.
    """

    def __init__(self, chs, dropout=0.25):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)
        self.dropout_p  = dropout

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
            # Gate dropout on this module's train/eval flag. The previous
            # `nn.Dropout(0.25)(x)` built a fresh module each call, which is always
            # in training mode, so dropout stayed on even after model.eval().
            x = F.dropout(x, p=self.dropout_p, training=self.training)
        return ftrs


class Decoder(nn.Module):
    """UNet expanding path.

    Stage ``i`` upsamples with a stride-2 ConvTranspose2d (chs[i] -> chs[i+1]
    channels), concatenates ``encoder_features[i]`` along the channel dim,
    applies dropout, then ``Block(chs[i], chs[i+1])``. Each skip feature must
    therefore have ``chs[i] - chs[i+1]`` channels and twice the incoming
    spatial size.

    Args:
        chs: channel sizes per stage, e.g. (256, 128, 64, 32).
        dropout: dropout probability after each concatenation (default 0.25).
    """

    def __init__(self, chs, dropout=0.25):
        super().__init__()
        self.chs        = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.dropout_p  = dropout

    def forward(self, x, encoder_features):
        # Indexing (not zip) keeps the IndexError if too few skip features are passed.
        for i, (upconv, block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)
            x = torch.cat([x, encoder_features[i]], dim=1)
            # A fresh nn.Dropout is always in training mode; gate on self.training
            # so dropout is disabled after model.eval().
            x = F.dropout(x, p=self.dropout_p, training=self.training)
            x = block(x)
        return x


class UNet(nn.Module):
    """Encoder -> Decoder (with skip connections) -> 3x3 conv head with ``num_class`` channels.

    Inputs whose last dimension is 256 are resized to 128x128, run through the
    network, and the prediction is resized back to 256x256. Any other size is
    processed as is. In eval mode the output goes through ReLU (non-negative
    predictions); in training mode the raw head output is returned.
    """

    def __init__(self, enc_chs, dec_chs, num_class):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 3, padding=1)

    def forward(self, x):
        if x.shape[-1] != 256:
            return self._forward(x)
        # Full-resolution input: run the network at 128 px, then upsample the prediction.
        x_small = v2.functional.resize(x, (128, 128), antialias=False)
        return v2.functional.resize(self._forward(x_small), (256, 256), antialias=False)

    def _forward(self, x):
        # Deepest encoder feature feeds the decoder; the rest are skips, coarsest first.
        deepest, *skips = reversed(self.encoder(x))
        out = self.head(self.decoder(deepest, skips))
        return out if self.training else torch.relu(out)

In [None]:
def process_test(model, test_dataset,  output_file='output.hdf5', use_transforms=True):
    """Predict every test sample with ``model`` and write the frames to an HDF5 file.

    Each sample must look like ``((inputs, last_input_timestamp), _)``, as produced
    by ``RadarDataset(..., with_time=True)``. Predicted frame ``k`` (1-based) is
    stored as float16 dataset ``intensity`` in group
    ``str(last_input_timestamp + 600 * k)``, after cropping the 2-pixel border
    that ``RadarDataset.prepare_array`` pads on.

    Args:
        model: network; inputs are moved to the device of its first parameter.
        test_dataset: iterable of samples in the format above.
        output_file: HDF5 path, opened once in append mode. Groups that already
            exist (e.g. from an earlier run into the same file) are reused and
            their ``intensity`` dataset is overwritten instead of raising.
        use_transforms: if True, average the predictions over the test-time
            augmentations in TEST_TRANSFORM_LIST, mapping each prediction back
            with the matching entry of TEST_TRANSFORM_LIST_REVERSE.

    Leaves ``model`` in eval mode.
    """
    model.eval()
    device = next(model.parameters()).device
    pad = 2  # border added by RadarDataset.prepare_array (F.pad with (2, 2, 2, 2))
    with h5py.File(output_file, mode='a') as f_out, torch.no_grad():
        for (inputs, last_input_timestamp), _ in test_dataset:
            inputs = inputs.unsqueeze(0).to(device)
            if use_transforms:
                outs = [
                    reverse(model(forward(inputs)))
                    for forward, reverse in zip(TEST_TRANSFORM_LIST, TEST_TRANSFORM_LIST_REVERSE)
                ]
                output = torch.stack(outs).mean(0)
            else:
                output = model(inputs)
            # One device->host copy per sample instead of one per frame.
            frames = output[0, :, pad:-pad, pad:-pad].cpu().numpy().astype(np.float16)
            for step, frame in enumerate(frames, start=1):
                timestamp_out = str(int(last_input_timestamp) + 600 * step)
                group = f_out.require_group(timestamp_out)
                if 'intensity' in group:
                    del group['intensity']
                group.create_dataset('intensity', data=frame)

In [None]:
def train_epoch(model, loader, optimizer, scheduler):
    """Run one training epoch with a masked, per-output-channel RMSE-style loss.

    For each ``(x, y)`` batch, pixels with ``y < 0`` (the invalid/padding marker
    written by ``RadarDataset.prepare_array``) are masked out. For each output
    channel, the masked squared error is summed over batch and space and divided
    by the batch size. The loss is the mean over channels of the square root of
    that value.

    Args:
        model: network; batches are moved to the device of its first parameter.
        loader: iterable of ``(x, y)`` tensor pairs.
        optimizer: stepped once per batch.
        scheduler: optional LR scheduler, stepped once per batch (after the optimizer).
    """
    model.train()
    device = next(model.parameters()).device
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        mask = y >= 0

        optimizer.zero_grad()
        pred = model(x)
        per_channel = torch.square((pred - y) * mask).sum((0, 2, 3)) / y.shape[0]
        # A fully masked channel gives exactly 0, where sqrt's derivative is infinite
        # and backprop yields NaN (inf * 0). Clamping leaves normal losses unchanged
        # and gives such channels a zero gradient instead.
        loss = torch.sqrt(per_channel.clamp_min(1e-12)).mean()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

In [None]:
def set_seed(seed=451):
    """Seed NumPy's global RNG and PyTorch (CPU, current and all CUDA devices); force deterministic cuDNN.

    Args:
        seed: integer seed shared by every generator (default 451).
    """
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
set_seed(1204)

NUM_EPOCHS=45
LR=5e-4
BATCH_SIZE=8
PCT_START=0.3
# Per-epoch prediction files go here; h5py cannot create a missing parent directory.
PRED_DIR = os.path.join('..', 'individual_predictions')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

os.makedirs(PRED_DIR, exist_ok=True)

train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# enc_chs[0] = 16 = 4 input timestamps x 4 channels (RadarDataset defaults);
# num_class = 12 = RadarDataset's default out_seq_len.
model=UNet(enc_chs=(16,32,64,128,256), dec_chs=(256,128,64,32), num_class=12).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# OneCycleLR is stepped once per batch inside train_epoch.
scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR,pct_start=PCT_START,steps_per_epoch=len(train_loader), epochs=NUM_EPOCHS)

for epoch in tqdm(range(NUM_EPOCHS)):
    train_epoch(model, train_loader, optimizer, scheduler)
    process_test(model, test_dataset,  output_file=os.path.join(PRED_DIR,f'FINAL-TWELVE-{epoch:02}.hdf5'), use_transforms=True)