# Prepare source domain dataloaders and experts

In [1]:
# Add to path
import sys
import pandas as pd
import torch
from torch import nn
import numpy as np
from torch import optim
import tqdm
LOAD_EXPERTS = True
DEVICE = 'cpu'

sys.path.append('../domain-adaptation')

In [2]:


df_train = pd.read_csv('/Users/nikglukhov/n.glukhov/canonical-paritioned-dataset/shifts_canonical_train.csv')
df_dev_in = pd.read_csv('/Users/nikglukhov/n.glukhov/canonical-paritioned-dataset/shifts_canonical_dev_in.csv')
df_dev_out = pd.read_csv('/Users/nikglukhov/n.glukhov/canonical-paritioned-dataset/shifts_canonical_dev_out.csv')
df_dev = pd.concat([df_dev_in, df_dev_out])

domains_train = df_train.climate.unique()

In [3]:
import torch

class Dataset(torch.utils.data.Dataset):
    def __init__(self, df, climate):
        self.df = df
        if climate is not None:
            self.X_source_domain = df[df.climate == climate].iloc[:,6:].copy()
            self.y_source_domain = df[df.climate == climate]['fact_temperature'].copy()
            self.climate = climate
        else:
            self.X_source_domain = df.iloc[:,6:].copy()
            self.y_source_domain = df['fact_temperature'].copy()
            self.climate = climate

        assert len(self.X_source_domain) == len(self.y_source_domain)

    def __len__(self):
        return len(self.y_source_domain)

    def __getitem__(self, index):
        X = torch.tensor(self.X_source_domain.iloc[index].values).to(torch.float32)
        y = torch.tensor(self.y_source_domain.iloc[index]).to(torch.float32)
        metadata = {
            'climate': self.climate if self.climate is not None else self.df.iloc[index].climate
        }
        return X, y, metadata

### Create source domain loaders

In [4]:
batch_size = 64
source_train_domains_loaders =  {
    climate: torch.utils.data.DataLoader(Dataset(df_train, climate), batch_size = batch_size)
    for climate in domains_train
}
train_loader = torch.utils.data.DataLoader(Dataset(df_train, None), batch_size = batch_size)
val_loader = torch.utils.data.DataLoader(Dataset(df_dev, climate=None), batch_size = batch_size)

In [5]:
x, y, metadata = next(iter(train_loader))

In [11]:
metadata['climate']

['dry',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'dry',
 'mild temperate',
 'mild temperate',
 'tropical',
 'mild temperate',
 'dry',
 'mild temperate',
 'dry',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'tropical',
 'dry',
 'tropical',
 'dry',
 'dry',
 'mild temperate',
 'tropical',
 'mild temperate',
 'tropical',
 'tropical',
 'dry',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'dry',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'dry',
 'mild temperate',
 'mild temperate',
 'dry',
 'mild temperate',
 'dry',
 'dry',
 'dry',
 'mild temperate',
 'tropical',
 'dry',
 'dry',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'mild temperate',
 'tropical',
 'mild temperate',
 'dry',
 'tropical',
 'tropical']

### Train or Load experts

In [12]:
source_domains_experts = {}
def initializeExperts():

    model = nn.Sequential(nn.Linear(123, 1)).to(DEVICE)

    for climate in domains_train:
        source_domains_experts[climate] = model
initializeExperts()

In [13]:
source_domains_experts

{'dry': Sequential(
   (0): Linear(in_features=123, out_features=1, bias=True)
 ),
 'mild temperate': Sequential(
   (0): Linear(in_features=123, out_features=1, bias=True)
 ),
 'tropical': Sequential(
   (0): Linear(in_features=123, out_features=1, bias=True)
 )}

# Master algorithm

In [66]:
from transformer import Transformer

In [8]:
class DivideModel(nn.Module):
    def __init__(self, original_model, layer=-1):
        super(DivideModel, self).__init__()
        self.num_ftrs = original_model.fc.in_features
        self.num_class = original_model.fc.out_features
        self.features = None # Change features here. Example: nn.Sequential(*list(o"riginal_model.children())[:layer])
        self.classifier = None # Change predictor here. Example: nn.Sequential(*list(original_model.children())[layer:])
        
    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, self.num_ftrs)
        x = self.classifier(x)
        x = x.view(-1, self.num_class)
        return x

In [9]:
def StudentModel(device, load_path=None):
    model = None# Place a backbone model here
    ## change some model properties if needed; Ex.:
    ## num_ftrs = model.fc.in_features
    ## model.fc = nn.Linear(num_ftrs, num_classes)
    if load_path:
        model.load_state_dict(torch.load(load_path))
    model = DivideModel(model)
    model = model.to(device)
    return model

In [14]:
def features_mask(features, domains, climate):
    mask = (domains == climate).nonzero()
    features[(domains == climate).nonzero()[0]] = torch.zeros_like(features[0])
    return features

In [28]:
expert = source_domains_experts['dry']
climate = 'dry'

In [27]:
[features_mask(expert(x_sup.float()).detach(), domain, climate) for climate, expert in source_domains_experts.items()]

[tensor([[   0.0000],
         [4061.8696],
         [3191.2437],
         [4059.6096],
         [   0.0000],
         [4234.9868],
         [4177.4487],
         [3901.5540],
         [4044.8813],
         [   0.0000],
         [3368.3853],
         [   0.0000],
         [3109.7329],
         [3882.3821],
         [3783.6377],
         [4123.1855],
         [4020.4978],
         [4085.1965],
         [3799.0383],
         [      nan],
         [   0.0000],
         [3331.3696],
         [   0.0000],
         [   0.0000],
         [3915.4246],
         [4016.7466],
         [3193.8394],
         [3339.3872],
         [3094.5400],
         [   0.0000],
         [4004.0603],
         [4033.5818]]),
 tensor([[4060.4006],
         [   0.0000],
         [   0.0000],
         [   0.0000],
         [4030.6379],
         [   0.0000],
         [   0.0000],
         [3901.5540],
         [   0.0000],
         [4121.7065],
         [   0.0000],
         [3775.0737],
         [   0.0000],
        

In [34]:
for x, y_true, metadata in train_loader:
    domain = np.array(metadata['climate'])

    sup_size = x.shape[0]//2
    x_sup = x[:sup_size]
    y_sup = y_true[:sup_size]
    x_que = x[sup_size:]
    y_que = y_true[sup_size:]
    domain = domain[:sup_size]

    with torch.no_grad():
        logits = torch.stack(
            [
            features_mask(expert(x_sup.float()).detach(), domain, climate)
            for climate, expert in source_domains_experts.items()
            ], dim=-1)
        logits = logits.permute((0, 2, 1))
    break
logits.shape

torch.Size([32, 3, 1])

In [17]:
for x, y_true, metadata in train_loader:
        
    domain = np.array(metadata['climate'])
    

    sup_size = x.shape[0]//2
    x_sup = x[:sup_size]
    y_sup = y_true[:sup_size]
    x_que = x[sup_size:]
    y_que = y_true[sup_size:]
    domain = domain[:sup_size]

    break

In [26]:
for climate, expert in source_domains_experts.items():
    break

In [None]:
transformer.att

In [99]:
transformer = Transformer(1, 2, 1, 128)

In [102]:
selector = fa_selector(1, 2, 1, 128)

In [69]:
for x, y_true, metadata in train_loader:
    domain = np.array(metadata['climate'])

    sup_size = x.shape[0]//2
    x_sup = x[:sup_size]
    y_sup = y_true[:sup_size]
    x_que = x[sup_size:]
    y_que = y_true[sup_size:]
    domain = domain[:sup_size]

    with torch.no_grad():
        logits = torch.stack(
            [
            features_mask(expert(x_sup.float()).detach(), domain, climate)
            for climate, expert in source_domains_experts.items()
            ], dim=-1)
        logits = logits.permute((0, 2, 1))
    break
logits.shape

torch.Size([32, 3, 1])

In [None]:
def l2_loss(input, target):
    loss = torch.square(target - input)
    loss = torch.mean(loss)
    return loss

def train_epoch(selector, selector_name, source_domains_experts, student, student_name, 
                train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
                device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3,
                batch_size=256, sup_size=24, test_way='id', save=False,
                root_dir='data'):
    for _, expert in source_domains_experts:
        expert.eval()
    
    student_ce = nn.BCEWithLogitsLoss()

    
    features = student.features
    head = student.classifier
    features.to(device)
    head.to(device)
    
    all_params = list(features.parameters()) + list(head.parameters())
    optimizer_s = optim.Adam(all_params, lr=slr)
    optimizer_t = optim.Adam(selector.parameters(), lr=tlr)
    
    i = 0
    
    losses = []
    
    iter_per_epoch = len(train_loader)
        
    for x, y_true, metadata in train_loader:
        selector.eval()
        head.eval()
        features.eval()
        
        domain = np.array(metadata['climate'])
        
    
        sup_size = x.shape[0]//2
        x_sup = x[:sup_size]
        y_sup = y_true[:sup_size]
        x_que = x[sup_size:]
        y_que = y_true[sup_size:]
        domain = domain[:sup_size]

        x_sup = x_sup.to(device)
        y_sup = y_sup.to(device)
        x_que = x_que.to(device)
        y_que = y_que.to(device)
        

        _squeeze = True
        with torch.no_grad():
            logits = torch.stack(
                [
                features_mask(expert(x_sup).detach(), domain, climate)
                for climate, expert in source_domains_experts.items()
                ], dim=-1)
            ### Expert input: [BS, 123]; Expert output: [BS, N]
            ### logits -> [BS, N, 3].
            logits = logits.permute((0,2,1))
                
            
            #logits = torch.stack([expert(x_sup).detach() for expert in experts_list], dim=-1)
            #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]])
            #
            #logits = mask_feat(logits, mask, len(models_list), exclude=True)
        
        t_out = selector.get_feat(logits)  

        task_model = features.clone()
        task_model.module.eval()
        feat = task_model(x_que)
        feat = feat.view(feat.shape[0], -1)
        out = head(feat)
        with torch.no_grad():
            loss_pre = student_ce(out, y_que.unsqueeze(-1).float()).item()/x_que.shape[0]
        
        feat = task_model(x_sup)
        feat = feat.view_as(t_out)

        inner_loss = l2_loss(feat, t_out)
        task_model.adapt(inner_loss)
        
        x_que = task_model(x_que)
        x_que = x_que.view(x_que.shape[0], -1)
        s_que_out = head(x_que)
        s_que_loss = student_ce(s_que_out, y_que.unsqueeze(-1).float())
        #t_sup_loss = teacher_ce(t_out, y_sup)
        
        s_que_loss.backward()
        
        optimizer_s.step()
        optimizer_t.step()
        optimizer_s.zero_grad()
        optimizer_t.zero_grad()
        
        ### Print some validation info
        ### Code here
        ###

        losses.append(s_que_loss.item()/x_que.shape[0])
        
            
        i += 1
    return None

In [None]:
def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device,
             batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30,
             decayRate=0.96, save=False, test_way='ood', root_dir='data'):
    
    train_loader = get_data_loader()
    for epoch in range(num_epochs):
        some_train_loss_value = train_epoch(selector, selector_name, models_list, student, student_name, 
                                train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
                                device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr,
                                batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save,
                                root_dir=root_dir) # need to remove some input variables
        some_eval_loss_value = eval(selector, models_list, student, sup_size, device=device, 
                    ilr=ilr, test=False, progress=False, uniform_over_groups=False,
                    root_dir=root_dir)

        ### 
        # print results
        # save model

        tlr = tlr*decayRate
        slr = slr*decayRate