In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time

# Data

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor, Compose

import os
import random
import datetime
from copy import deepcopy

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from utils.dataset import CutOrPad, get_rgb

In [None]:
PASTIS9 = './data/PASTIS9/'  # root folder of the PASTIS9 sample files
PATH = PASTIS9               # active data root used by the cells below
files = os.listdir(PATH)     # bare file names (no directory prefix)

In [None]:
# Peek at one random sample to check its keys and array shapes.
# os.path.join works whether or not PATH ends with a separator (same as PASTIS.__getitem__).
# NOTE(review): read_pickle can execute arbitrary code; only open trusted files.
file = random.choice(files)
data = pd.read_pickle(os.path.join(PATH, file))

print(data.keys())
print('Image: ', data['img'].shape)
print('Labels: ', data['labels'].shape, data['labels'])
print('DOY: ', data['doy'].shape, data['doy'])

In [None]:
class PASTIS(Dataset):
    """Map-style dataset over a directory of pickled PASTIS samples.

    Each file is read with ``pd.read_pickle`` and must provide the keys ``'img'``
    (T, C, H, W), ``'labels'`` and ``'doy'`` (T,). ``__getitem__`` returns
    ``(img, labels)``. ``img`` is normalized per sample and per channel, has the DOY
    appended as one extra channel, and is then passed through ``CutOrPad``.

    Args:
        pastis_path: directory containing the sample files.
        max_files: keep only the first ``max_files`` names (after sorting); None keeps all.

    NOTE(review): pickle files can execute arbitrary code when loaded; only point
    this at trusted data.
    """

    def __init__(self, pastis_path, max_files=500):
        self.pastis_path = pastis_path

        # listdir order is filesystem-dependent; sort so the same subset is chosen everywhere.
        self.file_names = sorted(os.listdir(self.pastis_path))[:max_files]

        # NOTE(review): unseeded, and redundant when the DataLoader already shuffles.
        random.shuffle(self.file_names)

        self.to_cutorpad = CutOrPad()

    def __len__(self):
        return len(self.file_names)

    def add_date_channel(self, img, doy):
        """Concatenate ``doy`` (T, 1, H, W) onto ``img`` (T, C, H, W) along the channel axis."""
        return torch.cat((img, doy), dim=1)

    def normalize(self, img, eps=1e-6):
        """Standardize each channel of ``img`` (T, C, H, W) over time, height and width.

        ``eps`` floors the standard deviation so a constant channel maps to 0
        instead of NaN (0 / 0).
        """
        C = img.shape[1]
        mean = img.mean(dim=(0, 2, 3)).to(torch.float32).reshape(1, C, 1, 1)
        std = img.std(dim=(0, 2, 3)).to(torch.float32).reshape(1, C, 1, 1)
        return (img - mean) / std.clamp_min(eps)

    def __getitem__(self, idx):
        data = pd.read_pickle(os.path.join(self.pastis_path, self.file_names[idx]))

        img = self.normalize(torch.tensor(data['img'].astype('float32')))
        T, C, H, W = img.shape

        # 'int64' rather than 'long': NumPy's 'long' is the platform C long (32-bit on
        # Windows with NumPy < 2), while PyTorch losses expect int64 class indices.
        labels = torch.tensor(data['labels'].astype('int64'))

        # (T,) -> (T, 1, H, W): repeat each date over the image so it can be stored as a channel.
        doy = torch.tensor(data['doy'].astype('float32')).reshape(-1, 1, 1, 1).repeat(1, 1, H, W)

        data['img'] = self.add_date_channel(img, doy)  # DOY becomes the last channel
        data['labels'] = labels
        del data['doy']

        data = self.to_cutorpad(data)  # cut/pad to a fixed sequence length
        del data['seq_lengths']

        return data['img'], data['labels']

In [None]:
# Build the dataset and show how many samples it holds (len() rather than calling __len__ directly).
data = PASTIS(PATH)
len(data)

In [None]:
# Small shuffled loader used for the quick visual sanity check below.
dataset = DataLoader(data, shuffle=True, batch_size=4)

In [None]:
img, label = next(iter(dataset))
# Show the first sample as RGB; the trailing DOY channel is sliced off before conversion.
plt.imshow(get_rgb(img[0, :, :-1].numpy()))
print(label[0].unique())

# Network

In [None]:
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange



class PreNorm(nn.Module):
    """Pre-normalization wrapper: LayerNorm over the last dim (size ``dim``), then ``fn``.

    Any keyword arguments passed to ``forward`` are forwarded unchanged to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)



class FeedForward(nn.Module):
    """Position-wise transformer MLP.

    Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout,
    applied to the last dimension of the input.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Create the two projections in the same order as the Sequential lists them,
        # so parameter indices (net.0 / net.3) and random init order are unchanged.
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout), project, nn.Dropout(dropout))

    def forward(self, x):
        out = self.net(x)
        return out


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention without masking.

    Input and output are (batch, tokens, dim). Queries, keys and values have inner
    width ``heads * dim_head``. The output projection is replaced by Identity when there
    is a single head whose width already equals ``dim``.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        needs_projection = heads != 1 or dim_head != dim
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        batch, tokens, _ = x.shape
        n_heads = self.heads

        def split_heads(t):
            # (b, n, h*d) -> (b, h, n, d); heads are the outer factor of the last axis.
            return t.reshape(batch, tokens, n_heads, -1).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = scores.softmax(dim=-1)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)

        # (b, h, n, d) -> (b, n, h*d)
        mixed = mixed.transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(mixed)



class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder layers followed by a final LayerNorm.

    Each layer applies residual self-attention and then a residual MLP:
    ``x = x + Attn(LN(x)); x = x + MLP(LN(x))``.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ])
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        for attention, mlp in self.layers:
            x = x + attention(x)
            x = x + mlp(x)
        return self.norm(x)

In [None]:
class Classification(nn.Module):
    """TSViT-style classifier for satellite image time series.

    Expects ``x`` of shape (B, T, C + 1, H, W): ``C`` spectral channels followed by one
    day-of-year (DOY) channel, as built by ``PASTIS.__getitem__``. Returns (B, K)
    unnormalized scores, one per class, each read from its own class token.

    Args:
        img_height, img_width: spatial size H, W. Should be divisible by ``patch_size``,
            otherwise ``N`` and ``nh * nw`` disagree.
        in_channel: number of spectral channels C (excluding the DOY channel).
        patch_size: side P of the square, non-overlapping patches.
        embed_dim: token width d.
        max_time: stored as ``self.T``; not used by ``forward``.
        num_classes: number of classes K (one class token per class).
        num_head: attention heads per encoder layer.
        dim_feedforward: accepted but ignored; the encoder MLP width is ``4 * embed_dim``.
        num_layers: encoder depth.
    """

    def __init__(self, img_height=9, img_width=9, in_channel=10,
                       patch_size=3, embed_dim=128, max_time=60,
                       num_classes=20, num_head=4, dim_feedforward=2048,
                       num_layers=4
                ):
        super().__init__()

        self.H = img_height
        self.W = img_width
        self.P = patch_size
        self.C = in_channel
        self.d = embed_dim
        self.T = max_time
        self.K = num_classes

        self.d_model = self.d
        self.num_head = num_head
        # NOTE(review): the `dim_feedforward` argument is ignored here (kept for API compatibility).
        self.dim_feedforward = self.d
        self.num_layers = num_layers

        self.N = int(self.H * self.W // self.P**2)  # number of patches per frame
        self.nh = int(self.H / self.P)              # patches along the height
        self.nw = int(self.W / self.P)              # patches along the width

        # DeepSat transformer encoder: dim_head=32, MLP width 4*d.
        self.encoder = Transformer(self.d, self.num_layers, self.num_head, 32, self.d*4)

        # Patch embedding: a (1, P, P) kernel with equal stride embeds every P x P patch
        # of every frame independently into d channels.
        self.projection = nn.Conv3d(self.C, self.d, kernel_size=(1, self.P, self.P), stride=(1, self.P, self.P))

        # Temporal
        self.temporal_emb = nn.Linear(366, self.d)  # one-hot DOY (366 bins) -> d
        self.temporal_cls_token = nn.Parameter(torch.randn(1, self.N, self.K, self.d))  # (1, N, K, d)
        # NOTE(review): the temporal and spatial stages share one encoder (same weights).
        # TSViT-style models usually use two separate transformers; confirm this is intended.
        self.temporal_transformer = self.encoder

        # Spatial
        self.spatial_emb = nn.Parameter(torch.randn(1, self.N, self.d))  # (1, N, d)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, self.K, self.d))  # (1, K, d)
        self.spatial_transformer = self.encoder

        # Classification head: one scalar score per class token.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.d),
            nn.Linear(self.d, 1)
            )

    def forward(self, x):
        """Score the K classes for a batch of image time series.

        Args:
            x: (B, T, C + 1, H, W); the last channel holds the DOY.

        Returns:
            (B, K) tensor of unnormalized class scores.
        """
        # --- Tokenization: embed each P x P patch of each frame into d dims.
        x_sits = x[:, :, :-1, :, :]  # (B, T, C, H, W) -- exclude the DOY channel
        B, T, C, H, W = x_sits.shape
        # Conv3d expects channels before time. This must be a permute (axis swap): a
        # reshape would reinterpret the memory and scramble bands across dates.
        x_sits = x_sits.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
        x_sits = self.projection(x_sits)  # (B, d, T, nh, nw)
        x_sits = x_sits.reshape(B, self.d, T, self.nh * self.nw)  # (B, d, T, N)
        # No positional embedding yet: the temporal encoder sees the raw patch tokens.
        x_sits = x_sits.permute(0, 3, 2, 1)  # (B, N, T, d)

        # --- Temporal encoding: DOY -> one-hot (366 bins) -> linear projection.
        # The dataset repeats the DOY over H and W, so pixel (0, 0) is sufficient.
        xt = x[:, :, -1, 0, 0]  # (B, T)
        # NOTE(review): one_hot requires DOY in [0, 365]; a 1-based DOY of 366 would raise. Confirm the data range.
        xt = F.one_hot(xt.to(torch.int64), num_classes=366).to(torch.float32)  # (B, T, 366)
        Pt = self.temporal_emb(xt)  # (B, T, d)

        # --- Temporal encoder: for every patch, prepend K class tokens to its T time tokens.
        x = x_sits + Pt.unsqueeze(1)  # (B, N, T, d)
        temporal_cls_token = self.temporal_cls_token.repeat(B, 1, 1, 1)  # (B, N, K, d)
        temporal_cls_token = temporal_cls_token.reshape(B * self.N, self.K, self.d)  # (B*N, K, d)
        x = x.reshape(B * self.N, T, self.d)  # (B*N, T, d)
        x = torch.cat([temporal_cls_token, x], dim=1)  # (B*N, K+T, d)
        x = self.temporal_transformer(x)  # (B*N, K+T, d)
        x = x.reshape(B, self.N, self.K + T, self.d)  # (B, N, K+T, d)
        x = x[:, :, :self.K, :]  # keep only the class tokens: (B, N, K, d)
        x = x.permute(0, 2, 1, 3)  # (B, K, N, d)
        x = x.reshape(B * self.K, self.N, self.d)  # (B*K, N, d)

        # --- Spatial encoder: per class, add patch position embeddings and prepend one class token.
        x = x + self.spatial_emb  # (B*K, N, d)
        spatial_cls_token = self.spatial_cls_token.unsqueeze(2)  # (1, K, 1, d)
        spatial_cls_token = spatial_cls_token.repeat(B, 1, 1, 1)  # (B, K, 1, d)
        spatial_cls_token = spatial_cls_token.reshape(B * self.K, 1, self.d)  # (B*K, 1, d)
        x = torch.cat([spatial_cls_token, x], dim=1)  # (B*K, 1+N, d)
        x = self.spatial_transformer(x)  # (B*K, 1+N, d)

        # --- Classification head: one score from each class's spatial class token.
        classes = x[:, 0, :].reshape(B, self.K, self.d)  # (B, K, d)
        x = self.mlp_head(classes)  # (B, K, 1)
        return x.reshape(B, self.K)  # (B, K)

# Training

### Loss Function

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Credits to  github.com/clcarwin/focal_loss_pytorch

    Args:
        gamma: focusing exponent; 0 gives (alpha-weighted) cross-entropy.
        alpha: per-class weights indexed by the target class. A list is converted
            to a tensor; a float/int ``a`` becomes ``[a, 1 - a]`` (binary case);
            None disables weighting.
        reduction: 'mean', 'sum', or None (return the per-element losses).

    ``input`` is (N, C) logits or (N, C, ...) maps; ``target`` holds int64 class
    indices with one entry per scored element.
    """
    def __init__(self, gamma=8, alpha=torch.ones(20), reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.reduction = reduction

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)  # log p_t per element
        # p_t is a constant in the modulating factor (no gradient through it). The reference
        # used the deprecated, never-imported `Variable(logpt.data.exp())`; detach() is equivalent.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Match the logits' device and dtype without mutating the stored weights.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction is None:
            return loss
        elif self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            raise ValueError(
                "FocalLoss: reduction parameter not in list of acceptable values [\"mean\", \"sum\", None]")

## Train

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data
batch_size = 8
dataset = DataLoader(data, batch_size=batch_size, shuffle=True)
# Count samples directly: len(dataset) * batch_size over-counts when the last batch is partial.
num_samples = len(data)

# Model
model = Classification(img_width=9, img_height=9, in_channel=10, patch_size=3, embed_dim=128, max_time=60, num_head=8, num_layers=8, num_classes=20)
model.to(device)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Number of Parameters: ', num_params)

# Loss
# NOTE(review): no criterion is active, so the training loop raises NameError on a fresh
# kernel. The model returns (B, num_classes) scores, while the labels appear to be per-pixel
# maps (see `label[0].unique()` above). Pick a loss and target encoding that match, then
# enable one of these:
# criterion = nn.CrossEntropyLoss()
# criterion = FocalLoss()
# criterion = nn.MSELoss()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=5e-3, momentum=0.9)
# optimizer = optim.AdamW(model.parameters(), lr=0.001)
epochs = 100
model.train()

## Trainer

### Basic

In [None]:
from tqdm import tqdm

# torch.save does not create missing directories.
os.makedirs('./weights', exist_ok=True)

for epoch in range(epochs):
    epoch_loss = 0.0
    t1 = time.time()

    for img, label in tqdm(dataset):
        img, label = img.to(device), label.to(device)

        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float: summing the loss tensor itself would chain every
        # batch's loss into one ever-growing autograd graph.
        epoch_loss += loss.item()

    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),  # last batch loss, stored as a plain float
        }, f'./weights/epoch_{epoch}.pt')

    t2 = time.time()
    # Mean-reduced criteria return a per-batch mean, so average over the number of batches.
    mean_loss = epoch_loss / max(len(dataset), 1)
    print('Epoch: ', epoch, 'Loss: ', mean_loss, f'Time: {t2 - t1:.1f}s')

# Evaluation

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Must match the training configuration (num_head=8, num_layers=8, num_classes=20).
# A different depth or head count changes the parameter shapes and keys, so
# load_state_dict on the saved checkpoint would fail.
model = Classification(img_width=9, img_height=9, in_channel=10, patch_size=3, embed_dim=128, max_time=60, num_head=8, num_layers=8, num_classes=20)
model.to(device)

In [None]:
# map_location allows a checkpoint written on a GPU to be loaded on a CPU-only machine.
checkpoint = torch.load('weights/epoch_60.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode; matters once any dropout > 0
# The optimizer state is also saved ('optimizer_state_dict') if training is resumed.
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Playground

In [None]:
# One (images, labels) batch for experiments; only the images are moved to the device.
x, y = next(iter(dataset))
x = x.to(device=device)

In [None]:
# The model returns (B, num_classes) scores, so the class axis is dim 1 (axis=3 raised
# IndexError). no_grad skips building an autograd graph for this inference-only call.
with torch.no_grad():
    scores = model(x)  # (B, num_classes)
torch.argmax(scores, dim=1)