In [1]:
import os
import sys
sys.path.remove('/home/jovyan/.imgenv-lm-poly-0/lib/python3.7/site-packages')
os.environ['PYTHONPATH'] = '/home/user/conda/envs/ya/lib/python3.10/site-packages'

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from glob import glob

In [2]:
CUDA_DEV = 0
NUM_TAGS = 256

In [3]:
df_train = pd.read_csv('train.csv')
df_test = pd.read_csv('test.csv')

In [4]:
from collections import Counter

tags = [[int(i) for i in x.split(',')] for x in df_train.tags.values]
dict_tags = {}
for cls_tags in tags:
    for c in cls_tags:
        if c not in dict_tags.keys():
            dict_tags[c] = Counter(cls_tags)
        else:
            dict_tags[c].update(Counter(cls_tags))
            
for tag in dict_tags.keys():
    del dict_tags[tag][tag]
    n = np.sum(list(dict_tags[tag].values()))
    for t in dict_tags[tag].keys():
        dict_tags[tag][t] = dict_tags[tag][t]/n

In [5]:
track_idx2embeds = {}
for fn in tqdm(glob('track_embeddings/*')):
    name = fn.split('/')[1].split('.')[0]
    if name == "track_embeddings":
        continue
    track_idx = int(name)
    embeds = np.load(fn)
    track_idx2embeds[track_idx] = embeds

100%|██████████| 76715/76715 [03:46<00:00, 338.22it/s]


In [6]:
class TaggingDataset(Dataset):
    def __init__(self, df, track_idx2embeds, aug=0, testing=False, label_smoothing = 0):
        self.df = df
        self.testing = testing
        self.aug = aug
        self.track_idx2embeds = track_idx2embeds
        self.label_smoothing = label_smoothing
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        track_idx = row.track
        embeds = self.track_idx2embeds[track_idx]
        if self.testing:
            return track_idx, embeds
        tags = [int(x) for x in row.tags.split(',')]
        target = np.zeros(NUM_TAGS)
        target[tags] = 1
        if self.label_smoothing > 0:
            eps = self.label_smoothing / 256
            target = target * (1 - self.label_smoothing) + eps
        
        if np.random.choice([0, 1], p=[1 - self.aug, self.aug]):
            s = np.random.uniform(0.0, 0.4)
            e = np.random.uniform(s+0.1, 1)
            s = int(s * embeds.shape[0])
            e = int(e * embeds.shape[0])
            embeds = embeds[s:e]
        
        return track_idx, embeds, target

In [7]:
train_dataset = TaggingDataset(df_train[:-1000], track_idx2embeds=track_idx2embeds, aug=0.6)
val_dataset = TaggingDataset(df_train[-1000:], track_idx2embeds=track_idx2embeds)

test_dataset = TaggingDataset(df_test, testing=True, track_idx2embeds=track_idx2embeds)

In [8]:
class FeedForward(nn.Module):
    def __init__(self, emb_dim=768, mult=4, p=0.0):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * mult),
            nn.Dropout(p),
            nn.GELU(),
            nn.Linear(emb_dim * mult, emb_dim)
        )

    def forward(self, x):
        return self.fc(x)
    
class AttentionPooling(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        self.attn = nn.Sequential(
            nn.Linear(embedding_size, embedding_size),
            nn.LayerNorm(embedding_size),
            nn.GELU(),
            nn.Linear(embedding_size, 1)
        )

    def forward(self, x, mask=None):
        attn_logits = self.attn(x)
        if mask is not None:
            attn_logits[mask] = -float('inf')
        attn_weights = torch.softmax(attn_logits, dim=1)
        x = x * attn_weights
        x = x.sum(dim=1)
        return x
    
class Network(nn.Module):
    def __init__(
        self,
        num_classes = NUM_TAGS,
        input_dim = 768,
        hidden_dim = 512
    ):
        super().__init__()
        self.num_classes = num_classes
        self.proj = FeedForward(input_dim)
        self.bn = nn.BatchNorm1d(input_dim)
        self.ln = nn.LayerNorm(input_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, activation="gelu", batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        self.poooling = AttentionPooling(input_dim)
        self.fc = nn.Linear(input_dim, num_classes)
               
    def forward(self, embeds):
        embeds = self.proj(embeds)
        src_key_padding_mask = (embeds.mean(-1) == -1)
        embeds = self.ln(embeds)
        x = self.transformer_encoder(embeds, src_key_padding_mask=src_key_padding_mask)
        x = self.bn(self.poooling(x, mask=src_key_padding_mask))
        outs = self.fc(x)
        return outs

In [9]:
from torch.nn.utils.rnn import pad_sequence

def predict(model, loader, max_length):
    model.eval()
    track_idxs = []
    predictions = []
    with torch.no_grad():
        for data in loader:
            track_idx, embeds = data
            embeds = [x.to(CUDA_DEV) for x in embeds]
            embeds = pad_sequence(embeds, padding_value=-1, batch_first=True)[:, :max_length, :]
            pred_logits = model(embeds)
            pred_probs = torch.sigmoid(pred_logits)
            predictions.append(pred_probs.cpu().numpy())
            track_idxs.append(track_idx.numpy())
    predictions = np.vstack(predictions)
    track_idxs = np.vstack(track_idxs).ravel()
    return track_idxs, predictions

In [10]:
from tqdm import tqdm

criterion = nn.BCEWithLogitsLoss().to(CUDA_DEV)

def predict_train(model, loader, max_length):
    model.eval()
    track_idxs = []
    predictions = []
    targets = []
    loss = 0
    with torch.no_grad():
        for data in loader:
            track_idx, embeds, target = data
            embeds = [x.to(CUDA_DEV) for x in embeds]
            embeds = pad_sequence(embeds, padding_value=-1, batch_first=True)[:, :max_length, :]
            pred_logits = model(embeds)
            pred_probs = torch.sigmoid(pred_logits)

            predictions.append(pred_probs.cpu().numpy())
            track_idxs.append(track_idx.numpy())
            targets.append(target.numpy())
    predictions = np.vstack(predictions)
    targets = np.vstack(targets)
    track_idxs = np.vstack(track_idxs).ravel()
    return track_idxs, predictions, targets

In [11]:
def collate_fn(b):
    track_idxs = torch.from_numpy(np.vstack([x[0] for x in b]))
    embeds = [torch.from_numpy(x[1]) for x in b]
    targets = np.vstack([x[2] for x in b])
    targets = torch.from_numpy(targets)
    return track_idxs, embeds, targets

def collate_fn_test(b):
    track_idxs = torch.from_numpy(np.vstack([x[0] for x in b]))
    embeds = [torch.from_numpy(x[1]) for x in b]
    return track_idxs, embeds

In [12]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn_test)

In [14]:
import sklearn.metrics
from sklearn.model_selection import KFold
from transformers import set_seed

paths_last = [
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034",
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777",
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9",
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908",
              
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049",
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039",
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42",
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323",
        ]

paths = [
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034",
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777",
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9",
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908",
    
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049",
        "./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039",
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42",
        "./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323",
    
        f"./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123", 
        f"./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777",
        f"./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999",
        f"./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034",
        f"./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123",
        f"./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123",
        f"./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777",
        f"./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238",
        f"./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323",
        f"./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999",]

meta_features = {}
meta_targets = {}
meta_features_test = {}
    
for k, path in tqdm(enumerate(paths)):
    seed = int(path.split('_')[-1])
    set_seed(seed)
    max_length = int(path.split('_')[-9])
    print(seed, max_length)
    
    preds_val = {}
    targets_val = {}
    #preds_test = []
    
    kf = KFold(n_splits=10, random_state=seed, shuffle=True)
    folds = list(kf.split(df_train))
    
    for model_path in os.listdir(path):
        if not model_path.endswith('.pt'):
            continue
        if (path in paths_last) and (not model_path.startswith('last')):
            continue
            
        model = Network()
        model = model.to(CUDA_DEV)
        model.load_state_dict(torch.load(f"{path}/{model_path}"))
        
        fold_i = int(model_path.split('_')[-1].split('.')[0])
        train_index, test_index = folds[fold_i]

        val_dataset = TaggingDataset(df_train.iloc[test_index], track_idx2embeds=track_idx2embeds)
        val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)

        track_idxs, predictions, targets = predict_train(model, val_dataloader, max_length=max_length)
        #track_idxs_test, predictions_test = predict(model, test_dataloader, max_length=max_length)
        
        for i, c in enumerate(predictions.argmax(-1)):
            probs = np.array([1 + dict_tags[c].get(t, 0) for t in np.arange(predictions.shape[1])])
            probs[c] = 2
            predictions[i] = predictions[i] * probs
            predictions[i] /= predictions[i].sum()
        
        ap = sklearn.metrics.average_precision_score(targets, predictions)
        print(f"Fold: {fold_i}, AP: {ap}, Seed: {seed}, model_path: {path}/{model_path}")

        for j, p, t in zip(track_idxs, predictions, targets):
            preds_val[j] = p
            targets_val[j] = t
                
        #preds_test.append(predictions_test)
        
#     predictions = np.mean(preds_test, axis=0)
#     for i, c in enumerate(predictions.argmax(-1)):
#         probs = np.array([1 + dict_tags[c].get(t, 0) for t in np.arange(predictions.shape[1])])
#         probs[c] = 2
#         predictions[i] = predictions[i] * probs
#         predictions[i] /= predictions[i].sum()
    
    if (path in paths_last):
        path_csv = f"{path}/prediction_last_mean_folds.csv"
    else:
        path_csv = f"{path}/prediction_mean_folds.csv"
    df = pd.read_csv(path_csv)
    predictions = np.array([[float(a) for a in x.split(',')] for x in df.prediction.values])
    meta_features_test[k] = predictions
    meta_targets[k] = targets_val
    meta_features[k] = preds_val

0it [00:00, ?it/s]

394529034 80
Fold: 0, AP: 0.27884541879629987, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034/last_model_0.pt
Fold: 1, AP: 0.2828230255433555, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034/last_model_1.pt
Fold: 2, AP: 0.2745290230899722, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034/last_model_2.pt
Fold: 3, AP: 0.2691909038746434, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034/last_model_3.pt
Fold: 4, AP: 0.28602370096100144, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034/last_model_4.pt
Fold: 5, AP: 0.27785399431951124, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_394529034/last_model_5.pt
Fold: 6, AP: 0.28111623719841894, Seed: 394529034, model_path: ./workdir/final_es/12_6_0.6

1it [01:06, 66.40s/it]

7777 80
Fold: 0, AP: 0.27635320139053154, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_0.pt
Fold: 1, AP: 0.27488561553025626, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_1.pt
Fold: 2, AP: 0.2704427621074261, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_2.pt
Fold: 3, AP: 0.2743421552679267, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_3.pt
Fold: 4, AP: 0.2766762182109663, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_4.pt
Fold: 5, AP: 0.27400464772267685, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_5.pt
Fold: 6, AP: 0.27949451395125047, Seed: 7777, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_7777/last_model_6.pt
Fold: 7, AP: 0.

2it [02:13, 66.61s/it]

9 80
Fold: 0, AP: 0.27922566105926, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_0.pt
Fold: 1, AP: 0.27787307266080324, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_1.pt
Fold: 2, AP: 0.28383336755227945, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_2.pt
Fold: 3, AP: 0.2801535177687319, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_3.pt
Fold: 4, AP: 0.28287155656859797, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_4.pt
Fold: 5, AP: 0.2744968366064029, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_5.pt
Fold: 6, AP: 0.2729221966926413, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_9/last_model_6.pt
Fold: 7, AP: 0.28054436745318206, Seed: 9, model_path: ./workdir/final/12_6_0.6_80_6

3it [03:19, 66.47s/it]

928431908 80
Fold: 0, AP: 0.2708670838566745, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908/last_model_0.pt
Fold: 1, AP: 0.2647746160165354, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908/last_model_1.pt
Fold: 2, AP: 0.2837612439285485, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908/last_model_2.pt
Fold: 3, AP: 0.2812777853517513, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908/last_model_3.pt
Fold: 4, AP: 0.2725020985569577, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908/last_model_4.pt
Fold: 5, AP: 0.2783112280253456, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_928431908/last_model_5.pt
Fold: 6, AP: 0.2877582490803553, Seed: 928431908, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_

4it [04:27, 67.25s/it]

12312049 80
Fold: 0, AP: 0.28413138680224204, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049/last_model_0.pt
Fold: 1, AP: 0.26863256473541053, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049/last_model_1.pt
Fold: 2, AP: 0.2809035723295671, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049/last_model_2.pt
Fold: 3, AP: 0.27657602660628455, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049/last_model_3.pt
Fold: 4, AP: 0.26998564128210556, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049/last_model_4.pt
Fold: 5, AP: 0.2774328575258035, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12312049/last_model_5.pt
Fold: 6, AP: 0.2763971420566914, Seed: 12312049, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_1

5it [05:40, 69.02s/it]

3490394039 80
Fold: 0, AP: 0.27814608115632644, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039/last_model_0.pt
Fold: 1, AP: 0.28009617863581004, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039/last_model_1.pt
Fold: 2, AP: 0.27418945747605095, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039/last_model_2.pt
Fold: 3, AP: 0.2807892971215152, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039/last_model_3.pt
Fold: 4, AP: 0.27624670488140957, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039/last_model_4.pt
Fold: 5, AP: 0.28063382735300174, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_3490394039/last_model_5.pt
Fold: 6, AP: 0.2774442662151546, Seed: 3490394039, model_path: ./workdir/final/12_6_0.6_80_64

6it [06:47, 68.39s/it]

42 80
Fold: 0, AP: 0.2731120156019412, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_0.pt
Fold: 1, AP: 0.27242955870342, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_1.pt
Fold: 2, AP: 0.2779880755365984, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_2.pt
Fold: 3, AP: 0.2788893159989052, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_3.pt
Fold: 4, AP: 0.2793577440815452, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_4.pt
Fold: 5, AP: 0.2763394670510687, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_5.pt
Fold: 6, AP: 0.27448640289546344, Seed: 42, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_42/last_model_6.pt
Fold: 7, AP: 0.2849842164472245, Seed: 42, model_p

7it [07:59, 69.77s/it]

12323 80
Fold: 0, AP: 0.27698466430384455, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_0.pt
Fold: 1, AP: 0.2804283556320519, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_1.pt
Fold: 2, AP: 0.2767936106237089, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_2.pt
Fold: 3, AP: 0.27396902414318847, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_3.pt
Fold: 4, AP: 0.2720107530612804, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_4.pt
Fold: 5, AP: 0.27870565746962955, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_5.pt
Fold: 6, AP: 0.28014423344812966, Seed: 12323, model_path: ./workdir/final_es/12_6_0.6_80_64_10_50_3e-05_1e-05_10_1e-07_12323/last_model_6.pt


8it [09:11, 70.46s/it]

123 64
Fold: 0, AP: 0.2693906611888211, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_0.pt
Fold: 1, AP: 0.28098337616629604, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_1.pt
Fold: 2, AP: 0.2707535993010135, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_2.pt
Fold: 3, AP: 0.2811971765589971, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_3.pt
Fold: 4, AP: 0.2677315776862632, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_4.pt
Fold: 5, AP: 0.2678676961238029, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_5.pt
Fold: 6, AP: 0.2696801092385323, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_6.pt
Fold: 7, AP: 0.26889339584209004, Seed: 123, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_123/model_7.pt

9it [10:15, 68.36s/it]

777 64
Fold: 0, AP: 0.2670125132140879, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_0.pt
Fold: 1, AP: 0.2746262511047279, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_1.pt
Fold: 2, AP: 0.26615747998557854, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_2.pt
Fold: 3, AP: 0.27265001333945094, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_3.pt
Fold: 4, AP: 0.274708643701947, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_4.pt
Fold: 5, AP: 0.2634172258501424, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_5.pt
Fold: 6, AP: 0.2741152957384707, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_6.pt
Fold: 7, AP: 0.27819836149459776, Seed: 777, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_777/model_7.pt

10it [11:19, 67.08s/it]

99999 64
Fold: 0, AP: 0.2710593848093555, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_0.pt
Fold: 1, AP: 0.26936157190204435, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_1.pt
Fold: 2, AP: 0.27136938401477356, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_2.pt
Fold: 3, AP: 0.27455449200039667, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_3.pt
Fold: 4, AP: 0.2621891587618945, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_4.pt
Fold: 5, AP: 0.27103314729631356, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_5.pt
Fold: 6, AP: 0.268289571201213, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3e-05_1e-05_10_1e-07_99999/model_6.pt
Fold: 7, AP: 0.27560189711860794, Seed: 99999, model_path: ./workdir/12_6_0.6_64_64_10_100_3

11it [12:22, 65.73s/it]

394529034 80
Fold: 0, AP: 0.26870576543306957, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_0.pt
Fold: 1, AP: 0.2779891018546478, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_1.pt
Fold: 2, AP: 0.27128695449268775, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_2.pt
Fold: 3, AP: 0.26615632071020445, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_3.pt
Fold: 4, AP: 0.2790239617563691, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_4.pt
Fold: 5, AP: 0.2775200690656471, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_5.pt
Fold: 6, AP: 0.2753753201379682, Seed: 394529034, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_394529034/model_6.pt
Fold: 7, AP: 0.27566541631395086

12it [13:29, 66.27s/it]

123 80
Fold: 0, AP: 0.26586786456777756, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_0.pt
Fold: 1, AP: 0.274374636901471, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_1.pt
Fold: 2, AP: 0.2735989392867765, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_2.pt
Fold: 3, AP: 0.2817634353593735, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_3.pt
Fold: 4, AP: 0.2684641569159902, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_4.pt
Fold: 5, AP: 0.27313762807653774, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_5.pt
Fold: 6, AP: 0.27270246968328643, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_6.pt
Fold: 7, AP: 0.2686550569392982, Seed: 123, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_7.pt

13it [14:40, 67.60s/it]

123 80
Fold: 0, AP: 0.2664262783388489, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_0.pt
Fold: 1, AP: 0.2804560942257841, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_1.pt
Fold: 2, AP: 0.27698059689839455, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_2.pt
Fold: 3, AP: 0.28420786983658575, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_3.pt
Fold: 4, AP: 0.2679084592920209, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_4.pt
Fold: 5, AP: 0.2688048120290476, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_5.pt
Fold: 6, AP: 0.27378583318778227, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_123/model_6.pt
Fold: 7, AP: 0.2752832061418351, Seed: 123, model_path: ./workdir/ls/12_6_0.6_80_64_10_100_3e-05_1e-0

14it [15:56, 70.12s/it]

777 80
Fold: 0, AP: 0.2676094031621189, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_0.pt
Fold: 1, AP: 0.27262403776614447, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_1.pt
Fold: 2, AP: 0.263804085472985, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_2.pt
Fold: 3, AP: 0.26965626191760045, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_3.pt
Fold: 4, AP: 0.27444034845953136, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_4.pt
Fold: 5, AP: 0.2672633516727342, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_5.pt
Fold: 6, AP: 0.27327148273270896, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_6.pt
Fold: 7, AP: 0.27839494229784467, Seed: 777, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_777/model_7.

15it [17:11, 71.59s/it]

1231238 80
Fold: 0, AP: 0.28117396080362816, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_0.pt
Fold: 1, AP: 0.27114924318980876, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_1.pt
Fold: 2, AP: 0.26928243020020304, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_2.pt
Fold: 3, AP: 0.275087020561557, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_3.pt
Fold: 4, AP: 0.27313623020107136, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_4.pt
Fold: 5, AP: 0.2661730038511637, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_5.pt
Fold: 6, AP: 0.27717678251237865, Seed: 1231238, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_1231238/model_6.pt
Fold: 7, AP: 0.2764084117186287, Seed: 1231238, model_path: .

16it [18:27, 72.80s/it]

12323 80
Fold: 0, AP: 0.27800373330926265, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_0.pt
Fold: 1, AP: 0.2771388453836147, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_1.pt
Fold: 2, AP: 0.2702892891909206, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_2.pt
Fold: 3, AP: 0.2713507338094403, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_3.pt
Fold: 4, AP: 0.2680784748012526, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_4.pt
Fold: 5, AP: 0.27203389487318796, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_5.pt
Fold: 6, AP: 0.2746276642254449, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_12323/model_6.pt
Fold: 7, AP: 0.2746939787241912, Seed: 12323, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-

17it [19:41, 73.39s/it]

99999 80
Fold: 0, AP: 0.27121578293776477, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_0.pt
Fold: 1, AP: 0.27413693266749173, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_1.pt
Fold: 2, AP: 0.2673873955551328, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_2.pt
Fold: 3, AP: 0.27685450351330837, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_3.pt
Fold: 4, AP: 0.26167061809967374, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_4.pt
Fold: 5, AP: 0.2702317776337357, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_5.pt
Fold: 6, AP: 0.25951624593590544, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3e-05_1e-05_10_1e-07_99999/model_6.pt
Fold: 7, AP: 0.273004212923609, Seed: 99999, model_path: ./workdir/12_6_0.6_80_64_10_100_3

18it [20:56, 69.82s/it]


In [15]:
track_idxs_test, _ = predict(model, test_dataloader, max_length=max_length) 

In [16]:
X = []
y = []
X_test_ = []
track_idxs = np.sort(list(meta_features[0].keys()))
for seed in meta_features.keys():
    X_i = [meta_features[seed][idx] for idx in track_idxs]
    y_i = [meta_targets[seed][idx] for idx in track_idxs]
    X.append(X_i)
    y.append(y_i)
    
    X_test_i = meta_features_test[seed]
    X_test_.append(X_test_i)

In [17]:
X_test = np.hstack(X_test_)
X = np.hstack(X)
y = np.array(y[0])

In [18]:
X.shape, y.shape, X_test.shape

((51134, 4608), (51134, 256), (25580, 4608))

In [19]:
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import ElasticNetCV, RidgeCV
from transformers import set_seed
set_seed(42)


probs = []
coefs = []

for c in tqdm(range(256)):
    model = ElasticNetCV(cv=10, random_state=42)
    clf = model.fit(X[:, c::256], y[:, c])
    coefs.append(clf.coef_)
    probs.append((clf.coef_[None, :] * X_test[:, c::256]).sum(-1) / clf.coef_.sum())

100%|██████████| 256/256 [03:44<00:00,  1.14it/s]


In [20]:
np.mean(coefs, axis=0)

array([0.23585235, 0.23420078, 0.23979515, 0.22679715, 0.23709166,
       0.23807167, 0.24267697, 0.21780336, 0.19627948, 0.2203923 ,
       0.22394888, 0.20279156, 0.2054293 , 0.18239388, 0.18671122,
       0.22973692, 0.21306786, 0.22827692], dtype=float32)

In [21]:
predictions = np.vstack(probs).T
predictions.shape

(25580, 256)

In [22]:
predictions.min(), predictions.max()

(-4.315174664313027e-05, 0.6768879981779364)

In [23]:
predictions[predictions<0] = 0

In [24]:
predictions.min(), predictions.max()

(0.0, 0.6768879981779364)

In [25]:
for i, c in enumerate(predictions.argmax(-1)):
    probs = np.array([1 + dict_tags[c].get(t, 0) for t in np.arange(predictions.shape[1])])
    probs[c] = 2
    predictions[i] = predictions[i] * probs
    predictions[i] /= predictions[i].sum()

In [26]:
predictions_df = pd.DataFrame([
    {'track': track, 'prediction': ','.join([str(p) for p in probs])}
    for track, probs in zip(track_idxs_test, predictions)
])
predictions_df.to_csv(f'./workdir/prediction_stacking_total.csv', index=False)

In [27]:
predictions[0].argsort()

array([154, 199,  91, 224,  77, 115, 250, 233, 223, 140, 217, 171,  88,
        80, 212, 218, 220, 136, 156,  82,  57, 242, 216, 191,  66, 119,
       211,  90, 149, 111, 163, 165, 100, 255,  49, 142, 150, 196, 132,
       152, 125, 185, 179, 203, 213,  86, 206, 104, 222,  79,  29, 237,
       249, 126, 147, 113, 127, 230, 197,  73, 205,  59, 130,  78, 254,
       229, 159, 207, 189, 226, 184, 146, 235, 247, 253, 182, 210,  53,
       131, 139, 153,  89, 102, 166, 133,  83, 208,  92,  60, 227,  56,
       219, 251,  64, 188,  51, 128,  71, 143,  42, 161, 144, 105, 244,
        48, 108, 193, 138, 200,  31, 238,  55,  69, 239,  65, 245, 114,
       181, 221,  10, 168, 172, 120, 135,  96,  94, 190,  18, 178,  44,
        52,  41,  27,  14,  58, 109, 157, 107, 231,  62,  13,  75, 141,
        16, 110,  46, 124,  67, 186,  70,  22,   4,  21, 129,  68,  40,
       243, 174,  17, 118,  87,   7, 180, 148,  37,  25,  33, 164, 209,
        99, 106,  45,  97,  20, 176, 173,  50, 240,  34, 192, 11