In [120]:
import os, sys
import pandas as pd
import numpy as np
from tqdm import tqdm
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F


#add ../ to path
sys.path.append(os.path.join(os.path.dirname("__file__"), '../'))

from src.data_manager import (
    init_data,
    make_transforms
)
from src.utils import init_distributed
import src.msn_train as msn

In [121]:
with open('configs/temp.yaml', 'r') as y_file:
    args = yaml.load(y_file, Loader=yaml.FullLoader)

In [122]:
# -- META
world_size, rank = 1,1
model_name = args['meta']['model_name']
two_layer = False if 'two_layer' not in args['meta'] else args['meta']['two_layer']
bottleneck = 1 if 'bottleneck' not in args['meta'] else args['meta']['bottleneck']
output_dim = args['meta']['output_dim']
hidden_dim = args['meta']['hidden_dim']
load_model = args['meta']['load_checkpoint']
r_file = args['meta']['read_checkpoint']
# copy_data = args['meta']['copy_data']
use_pred_head = args['meta']['use_pred_head']
use_bn = args['meta']['use_bn']
drop_path_rate = args['meta']['drop_path_rate']
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

# -- CRITERTION
memax_weight = 1 if 'memax_weight' not in args['criterion'] else args['criterion']['memax_weight']
ent_weight = 1 if 'ent_weight' not in args['criterion'] else args['criterion']['ent_weight']
freeze_proto = False if 'freeze_proto' not in args['criterion'] else args['criterion']['freeze_proto']
use_ent = False if 'use_ent' not in args['criterion'] else args['criterion']['use_ent']
reg = args['criterion']['me_max']
use_sinkhorn = args['criterion']['use_sinkhorn']
num_proto = args['criterion']['num_proto']
# --
# batch_size = args['criterion']['batch_size']
batch_size = 64
temperature = args['criterion']['temperature']
_start_T = args['criterion']['start_sharpen']
_final_T = args['criterion']['final_sharpen']

# -- DATA
label_smoothing = args['data']['label_smoothing']
pin_mem = False if 'pin_mem' not in args['data'] else args['data']['pin_mem']
num_workers = 1 if 'num_workers' not in args['data'] else args['data']['num_workers']
norm_means = args['data']['norm_means']
norm_stds = args['data']['norm_stds']
# root_path = args['data']['root_path']
# image_folder = args['data']['image_folder']
patch_drop = args['data']['patch_drop']
rand_size = args['data']['rand_size']
rand_views = args['data']['rand_views']
focal_views = args['data']['focal_views']
focal_size = args['data']['focal_size']
surf_vars = args['data']['surf_vars']
static_vars = args['data']['static_vars']
lat_lim = args['data']['lat_limit']
lon_lim = args['data']['lon_limit']
split_val = args['data']['split_val']

# --

# -- OPTIMIZATION
clip_grad = args['optimization']['clip_grad']
wd = float(args['optimization']['weight_decay'])
final_wd = float(args['optimization']['final_weight_decay'])
# num_epochs = args['optimization']['epochs']
num_epochs = 100
warmup = args['optimization']['warmup']
start_lr = args['optimization']['start_lr']
lr = args['optimization']['lr']
final_lr = args['optimization']['final_lr']

# -- LOGGING
folder = args['logging']['folder']
tag = args['logging']['write_tag']

In [None]:
from src.data_manager import BrazilWeatherDataset
import torch
import torchvision.transforms as transforms

normalize = transforms.Normalize(
            norm_means,
            norm_stds)


dataset = BrazilWeatherDataset( transform=normalize,
                                surf_vars=surf_vars,
                                static_vars=static_vars,
                                lat_lim=lat_lim, lon_lim=lon_lim,
                                adj_prep_balance=False,
                                split_val=True,
                                return_time_period=True)


data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    pin_memory=pin_mem,
    num_workers=num_workers,
    shuffle=True,)


In [None]:
# https://www.climate.gov/news-features/understanding-climate/climate-variability-oceanic-nino-index
oni_index = pd.read_excel('data/oni_index.xlsx')
oni_index.set_index('Year', inplace=True)
oni_index.columns.name='Month'
original_col = list( oni_index.columns)
new_cols = [i for i in range(2,13)]+[1] #ultimo mes da média móvel
oni_index.columns = new_cols
oni_index= oni_index.unstack().to_frame('ONI')

oni_index['date_period'] = pd.to_datetime(oni_index.index.map(lambda x: f"{x[1]}-{x[0]}-01")).to_period('M')
oni_index['date_period'] = oni_index['date_period'].apply(lambda x: x.ordinal)

oni_index.set_index('date_period', inplace=True)

In [None]:
model_name = 'deit_small_temperature'
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

encoder = msn.init_model(device=device,model_name=model_name)

INFO:root:VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(2, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1-11): 11 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True

In [None]:

class LabelDifference(nn.Module):
    def __init__(self, distance_type='l1'):
        super(LabelDifference, self).__init__()
        self.distance_type = distance_type

    def forward(self, labels):
        # labels: [bs, label_dim]
        # output: [bs, bs]
        if self.distance_type == 'l1':
            return torch.abs(labels[:, None, :] - labels[None, :, :]).sum(dim=-1)
        else:
            raise ValueError(self.distance_type)


class FeatureSimilarity(nn.Module):
    def __init__(self, similarity_type='l2'):
        super(FeatureSimilarity, self).__init__()
        self.similarity_type = similarity_type

    def forward(self, features):
        # labels: [bs, feat_dim]
        # output: [bs, bs]
        if self.similarity_type == 'l2':
            return - (features[:, None, :] - features[None, :, :]).norm(2, dim=-1)
        else:
            raise ValueError(self.similarity_type)


class RnCLoss(nn.Module):
    def __init__(self, temperature=2, label_diff='l1', feature_sim='l2'):
        super(RnCLoss, self).__init__()
        self.t = temperature
        self.label_diff_fn = LabelDifference(label_diff)
        self.feature_sim_fn = FeatureSimilarity(feature_sim)

    def forward(self, features, labels):
        # features: [bs, 2, feat_dim]
        # labels: [bs, label_dim]

        # features = torch.cat([features[:, 0], features[:, 1]], dim=0)  # [2bs, feat_dim]
        # labels = labels.repeat(2, 1)  # [2bs, label_dim]

        label_diffs = self.label_diff_fn(labels)
        logits = self.feature_sim_fn(features).div(self.t)
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits -= logits_max.detach()
        exp_logits = logits.exp()

        n = logits.shape[0]  # n = 2bs

        # remove diagonal
        logits = logits.masked_select((1 - torch.eye(n).to(logits.device)).bool()).view(n, n - 1)
        exp_logits = exp_logits.masked_select((1 - torch.eye(n).to(logits.device)).bool()).view(n, n - 1)
        label_diffs = label_diffs.masked_select((1 - torch.eye(n).to(logits.device)).bool()).view(n, n - 1)

        loss = 0.
        for k in range(n - 1):
            pos_logits = logits[:, k]  # 2bs
            pos_label_diffs = label_diffs[:, k]  # 2bs
            neg_mask = (label_diffs >= pos_label_diffs.view(-1, 1)).float()  # [2bs, 2bs - 1]
            pos_log_probs = pos_logits - torch.log((neg_mask * exp_logits).sum(dim=-1))  # 2bs
            loss += - (pos_log_probs / (n * (n - 1))).sum()

        return loss

In [None]:
criterion = RnCLoss(temperature=2, label_diff='l1', feature_sim='l2')
optimizer = torch.optim.AdamW(encoder.parameters(), lr=lr, weight_decay=wd)

In [None]:
for epoch in range(0, num_epochs):

    for itr, (udata, utime) in enumerate(data_loader):


        optimizer.zero_grad()

        _, z = encoder(udata.to(device), return_before_head=True, patch_drop=patch_drop)
        anchor_views =  z.float()
        labels = torch.tensor(oni_index.loc[utime.numpy()]['ONI'].values).unsqueeze(-1).to(device)

        loss = criterion(anchor_views, labels)

        loss.backward()
        optimizer.step()
        
    if epoch % 10 == 0:
        print(f"Epoch [ {epoch}/{num_epochs}] Iter [{itr}/{len(data_loader)}] Loss: {loss.item():.4f}")

Epoch [0/100] Iter [1397/1398] Loss: 0.9593
Epoch [10/100] Iter [1397/1398] Loss: 0.9825


KeyboardInterrupt: 