In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import time

# Data

In [2]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor, Compose

import os
import random
import datetime
from copy import deepcopy

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from utils.dataset import CutOrPad, get_rgb

In [3]:
PASTIS24 = './data/PASTIS24/'
PASTIS9 = './data/PASTIS9/'

PATH = PASTIS24

In [4]:
files = os.listdir(PATH)
file = random.choice(files)

In [5]:
data = pd.read_pickle(PATH + file)

print(data.keys())
print('Image: ', data['img'].shape)
print('Labels: ', data['labels'].shape, data['labels'])
print('DOY: ', data['doy'].shape, data['doy'])

dict_keys(['img', 'labels', 'doy'])
Image:  (43, 10, 24, 24)
Labels:  (24, 24) [[ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  0.  0.  1.
   1.  1.  1.  1.  1.  0.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  0.  0.  0.  1.  1.  1.  1.
   1.  1.  1.  1.  0.  0.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.  0.  0.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  0.  0.  0.  0.  0.]
 [ 1.  1.  1.  1.  0.  0.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  0.  0.  0.  0.  0.]
 [ 0.  0.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  0.  0.  0.  0.  0.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
   1.  0.  0.  0.  0.  0.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.  1.  3.  3.  3.  1.  1.  1.  1.  1.  1.
   1.  1.  0.  0.  0.  1.]
 [ 1.  1.  1.  1.  1.  1.  3.  3.  3.  3.  3.  3.  1.  1.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.]
 [ 1.  1.  3.  3.  3.  3.  3.  3.  3.  3.  3.  3.  1.  1.  1.  1.  1.  1.
   1.  1.  1.  1.  1.  1.]
 [ 3.  3.  3

In [6]:
class PASTIS(Dataset):
    def __init__(self, pastis_path):
        self.pastis_path = pastis_path

        self.file_names = os.listdir(self.pastis_path)[:500]

        random.shuffle(self.file_names)

        self.to_cutorpad = CutOrPad()
        # self.to_tiledates = TileDates(24, 24)
        # self.to_unkmask = UnkMask(unk_class=19, ground_truth_target='labels'))



    def __len__(self):
        return len(self.file_names)


    def add_date_channel(self, img, doy):
        img = torch.cat((img, doy), dim=1)
        return img


    def normalize(self, img):
        C = img.shape[1]
        mean = img.mean(dim=(0, 2, 3)).to(torch.float32).reshape(1, C, 1, 1)
        std = img.std(dim=(0, 2, 3)).to(torch.float32).reshape(1, C, 1, 1)

        img = (img - mean) / std

        return img


    def __getitem__(self, idx):
        data = pd.read_pickle(os.path.join(self.pastis_path, self.file_names[idx]))

        data['img'] = data['img'].astype('float32')
        data['img'] = torch.tensor(data['img'])
        data['img'] = self.normalize(data['img'])
        T, C, H, W = data['img'].shape

        data['labels'] = data['labels'].astype('long')
        data['labels'] = torch.tensor(data['labels'])
        # data['labels'] = F.one_hot(data['labels'].long(), num_classes=20)

        data['doy'] = data['doy'].astype('float32')
        data['doy'] = torch.tensor(data['doy'])
        data['doy'] = data['doy'].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        data['doy'] = data['doy'].repeat(1, 1, H, W)

        data['img'] = self.add_date_channel(data['img'], data['doy']) # add DOY to the last channel
        del data['doy'] # Delete DOY

        data = self.to_cutorpad(data) # Pad to Max Sequence Length
        del data['seq_lengths'] # Delete Sequence Length


        return data['img'], data['labels']

In [7]:
data = PASTIS(PATH)
data.__len__()

500

In [8]:
dataset = DataLoader(data, batch_size=4, shuffle=True)

In [None]:
img, label = next(iter(dataset))
plt.imshow(get_rgb(img[0][:,:-1,:,:].numpy()))
print(label[0].unique())

# Temporal-Spatial Vision Transformer

In [None]:
class Classification(nn.Module):
    def __init__(self, img_height=9, img_width=9, in_channel=10,
                       patch_size=3, embed_dim=512, max_time=60,
                       num_classes=20, num_head=4, dim_feedforward=2048,
                       num_layers=4
                ):
        super().__init__()
        
        self.H = img_height
        self.W = img_width
        self.P = patch_size
        self.C = in_channel
        self.d = embed_dim
        self.T = max_time
        self.K = num_classes

        self.d_model = self.d
        self.num_head = num_head
        self.dim_feedforward = dim_feedforward
        self.num_layers = num_layers

        self.N = int(self.H * self.W // self.P**2)
        # self.n = int(self.N**0.5)
        self.nh = int(self.H / self.P)
        self.nw = int(self.W / self.P)




        # Parameters
        self.encoderLayer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.num_head, dim_feedforward=self.dim_feedforward)
        self.encoder = nn.TransformerEncoder(self.encoderLayer, num_layers=self.num_layers)

        self.projection = nn.Conv3d(self.C, self.d, kernel_size=(1, self.P, self.P), stride=(1, self.P, self.P))


        self.temporal_emb = nn.Linear(366, self.d)
        self.temporal_cls_token = nn.Parameter(torch.randn(1, self.K, self.d)) # (1, K, d)
        self.temporal_transformer = self.encoder


        self.spatial_emb = nn.Parameter(torch.randn(1, self.N, self.d)) # (1, N, d)
        # self.spatial_cls_token = nn.Parameter(torch.randn(1, self.K, self.d)) # (1, K, d)
        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, self.d)) # (1, 1, d)
        self.spatial_transformer = self.encoder
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.d), nn.Linear(self.d, 1))


    def forward(self, x):

        '''
        Tekenization
        '''
        # remove the timestamps (last channel) from the input
        x_sits = x[:, :, :-1]
        B, T, C, H, W = x_sits.shape # (B, T, C, H, W)
        
        x_sits = x_sits.reshape(B, C, T, H, W) # (B, C, T, H, W)
        x_sits = self.projection(x_sits) # (B, d, T, nw, nh)
        x_sits = x_sits.view(B, self.d, T, self.nh*self.nw) # (B, d, T, N)

        # Spatial Encoding (Positional Embeddings)
        # we dont add pos embedding here, cuz we need the pure data for the temporal encoder
        # x_sits = x_sits + self.pos_emb # (B, d, T, N) 

        x_sits = x_sits.permute(0,3,2,1) # (B, N, T, d)



        '''
        Temporal Encoding
        '''
        # in the last channel lies the timestamp
        xt = x[:, :, -1, 0, 0] # (B, T, C, H, W)
        # convert to one-hot
        xt = F.one_hot(xt.to(torch.int64), num_classes=366).to(torch.float32) # (B, T, 366)
        Pt = self.temporal_emb(xt) # (B, T, d)


        '''
        Temporal Encoder: cat(Z+Pt)
        '''
        x = x_sits + Pt.unsqueeze(1) # (B, N, T, d)

        # CLS Token
        temporal_cls_token = self.temporal_cls_token # (1, K, d)
        temporal_cls_token = temporal_cls_token.repeat(B, self.N, 1, 1) # (B, N, K, d)
        x = torch.cat([temporal_cls_token, x], dim=2) # (B, N, K+T, d)
        x = x.view(B*self.N, self.K + T, self.d) # (B*N, K+T, d)

        x = self.temporal_transformer(x) # (B*N, K+T, d)
        x = x.view(B, self.N, self.K + T, self.d) # (B, N, K+T, d)
        x = x[:,:,:self.K,:] # (B, N, K, d)
        x = x.permute(0, 2, 1, 3) # (B, K, N, d)



        '''
        Spatial Encoding
        '''
        Ps = self.spatial_emb # (1, N, d)
        Ps = Ps.unsqueeze(1) # (1, 1, N, d)
        x = x + Ps # (B, K, N, d)
        
        # CLS Token
        spatial_cls_token = self.spatial_cls_token # (1, 1, d)
        spatial_cls_token = spatial_cls_token.repeat(B, self.K, 1) # (B, K, d)
        spatial_cls_token = spatial_cls_token.unsqueeze(2) # (B, K, 1, d)

        x = torch.cat([spatial_cls_token, x], dim=2) # (B, K, N + 1, d)
    
        x = x.view(B*(self.N+1), self.K, self.d) # (B*(N+1), K, d)
        x = self.spatial_transformer(x) # (B*(N+1), K, d)
        x = x.view(B, self.N+1, self.K, self.d) # (B, (N+1), K, d)
        classes = x[:,0,:,:] # (B, K, d)
        # print(classes)
        x = self.mlp_head(classes) # (B, K, 1)
        # print(x, x.shape)
        x = x.reshape(B, self.K) # (B, K)
        
        return x

# SVM

In [None]:
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X = [x[0,:-1,:,:].reshape(10, 81) for x, y in data]
Y = [y for x, y in data]

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

clf = svm.SVC(kernel='linear')
clf.fit(X_train, y_train)


# Training

## Loss Function

In [None]:
class FocalLoss(nn.Module):
    """
    Credits to  github.com/clcarwin/focal_loss_pytorch
    """
    def __init__(self, gamma=8, alpha=torch.ones(20), reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.reduction = reduction
        
    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction is None:
            return loss
        elif self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            raise ValueError(
                "FocalLoss: reduction parameter not in list of acceptable values [\"mean\", \"sum\", None]")

## Train

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Data
batch_size = 2
dataset = DataLoader(data, batch_size=batch_size, shuffle=True)
num_samples = dataset.__len__()*batch_size

# Model
model = Segmentation(img_width=24, img_height=24, in_channel=10, patch_size=3, embed_dim=128, max_time=60, num_head=8, num_layers=8, num_classes=20)
model.to(device)

num_params = sum([p.numel() for p in model.parameters() if p.requires_grad == True])
print('Number of Parameters: ', num_params)

# Loss
# criterion = nn.CrossEntropyLoss()
# criterion = FocalLoss()
# criterion = nn.MSELoss()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=5e-3, momentum=0.9)
# optimizer = optim.AdamW(model.parameters(), lr=0.001)
epochs = 100
model.train()

Number of Parameters:  1128329


Segmentation(
  (encoderLayer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=128, out_features=128, bias=True)
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): L

## Trainer

### Basic

In [13]:
from tqdm import tqdm
for epoch in range(epochs):
  epoch_loss = 0

  t1 = time.time()
  for batch in tqdm(dataset):
    img, label = batch
    img, label = img.to(device), label.to(device)


    optimizer.zero_grad()
    
    output = model(img)
    
    # print(f'Output shape: {output.shape} | Label shape: {label.shape}')
    # print('Output: ', output[0], 'Label: ', label[0])

    loss = criterion(output, label)
    epoch_loss += loss

    loss.backward()
    optimizer.step()


  if epoch % 10 == 0:
    torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': loss,
              }, f'./weights/epoch_{epoch}.pt')
  t2 = time.time()
  print('Epoch: ', epoch, 'Loss: ', (epoch_loss/num_samples)*100)

100%|██████████| 250/250 [00:42<00:00,  5.90it/s]


Epoch:  0 Loss:  tensor(150.0118, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 250/250 [00:43<00:00,  5.78it/s]


Epoch:  1 Loss:  tensor(149.5231, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 250/250 [00:44<00:00,  5.67it/s]


Epoch:  2 Loss:  tensor(149.3623, device='cuda:0', grad_fn=<MulBackward0>)


100%|██████████| 250/250 [00:44<00:00,  5.58it/s]


Epoch:  3 Loss:  tensor(149.1327, device='cuda:0', grad_fn=<MulBackward0>)


 31%|███       | 78/250 [00:14<00:31,  5.53it/s]


KeyboardInterrupt: 

In [19]:
torch.cuda.is_available()

True

### Ignite

In [None]:
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine


trainer = create_supervised_trainer(model, optimizer, criterion, device)

val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion)
}

train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

In [None]:
def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()
    print(batch,)
    x, y = batch[0].to(device), batch[0].to(device)
    print(x.shape, y.shape)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)

def validation_step(engine, batch):
    model.eval()
    with torch.no_grad():
        x, y = batch[0].to(device), batch[1].to(device)
        y_pred = model(x)
        return y_pred, y

train_evaluator = Engine(validation_step)
val_evaluator = Engine(validation_step)

# Attach metrics to the evaluators
for name, metric in val_metrics.items():
    metric.attach(train_evaluator, name)

for name, metric in val_metrics.items():
    metric.attach(val_evaluator, name)


In [None]:
log_interval = 100

In [None]:
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    print(f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")

In [None]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    train_evaluator.run(dataset)
    metrics = train_evaluator.state.metrics
    print(f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}")


In [None]:
# Score function to return current value of any metric we defined above in val_metrics
def score_function(engine):
    return engine.state.metrics["accuracy"]

# Checkpoint to store n_saved best models wrt score function
model_checkpoint = ModelCheckpoint(
    "checkpoint",
    n_saved=2,
    filename_prefix="best",
    score_function=score_function,
    score_name="accuracy",
    global_step_transform=global_step_from_engine(trainer), # helps fetch the trainer's state
)
  
# Save the model after every epoch of val_evaluator is completed
val_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

In [None]:
# Define a Tensorboard logger
tb_logger = TensorboardLogger(log_dir="tb-logger")

# Attach handler to plot trainer's loss every 100 iterations
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)

# Attach handler for plotting both evaluators' metrics after every epoch completes
for tag, evaluator in [("training", train_evaluator), ("validation", val_evaluator)]:
    tb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag=tag,
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer),
    )

In [None]:
trainer.run(dataset, max_epochs=5)

# Evaluation

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Classification(img_width=9, img_height=9, in_channel=10, patch_size=3, embed_dim=128, max_time=60)
model.to(device)

In [None]:
checkpoint = torch.load('weights/epoch_60.pt')
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:
loss

# Playground

In [None]:
x, y = next(iter(dataset))
x = x.to(device)

In [None]:
torch.argmax(model(x), axis=3)