In [1]:
import warnings

import torch
from torch.nn import functional as F
import numpy as np
# from tqdm import tqdm_notebook as tqdm
from tqdm import tqdm
from rsna19.configs.second_level import Config
from sklearn.metrics import log_loss
import pandas as pd
from scipy.signal import windows

In [2]:
# train_folds = [0, 1, 2, 3] # for testing 
train_folds = [0, 1, 2, 3, 4]  # for stage 2 submission
val_folds = [4]

cache_dir = '/home/dmytro/ml/kaggle-rsna-2019/output/cache'


def load_cached_folds(folds, name, root=cache_dir):
    """Load ``{root}/fold{f}/{name}.npy`` for every fold in ``folds`` and concatenate them row-wise.

    Returns a single float32 tensor. ``torch.as_tensor`` skips the intermediate copy when the
    array is already float32; ``torch.cat`` always allocates the final result.
    """
    return torch.cat(
        [torch.as_tensor(np.load(f'{root}/fold{f}/{name}.npy'), dtype=torch.float32) for f in folds],
        dim=0,
    )


train_x = load_cached_folds(train_folds, 'x')
train_y = load_cached_folds(train_folds, 'y')
val_x = load_cached_folds(val_folds, 'x')
val_y = load_cached_folds(val_folds, 'y')

# In the stage-2 setting every val fold is also a train fold, so all "val" losses reported
# further down are in-sample rather than a held-out estimate. Say so loudly.
if set(train_folds) & set(val_folds):
    warnings.warn(f'val_folds {sorted(set(train_folds) & set(val_folds))} are also in train_folds; '
                  'validation losses are not held-out')

n_models = 5
# Raw per-class weights sum to 7 (the last class counts double). Scaling by 6/7 makes the
# default mean reduction of F.binary_cross_entropy over the 6 columns equal the per-row
# weighted mean sum(w * l) / sum(w) — presumably the competition's weighted log loss.
class_weights = torch.tensor([1, 1, 1, 1, 1, 2], dtype=torch.float32) * 6 / 7
loss_fn = F.binary_cross_entropy

In [3]:
# undo sigmoid

# train_x[train_x > 0] = torch.log(train_x[train_x > 0] / (1-train_x[train_x > 0]))
# val_x[val_x > 0] = torch.log(val_x[val_x > 0] / (1-val_x[val_x > 0]))
# loss_fn = F.binary_cross_entropy_with_logits

In [4]:
def report_l1_losses(x, y, header, features_per_model=30, slice_start=12, n_classes=6):
    """Print the weighted loss of each L1 model's predictions and of their plain average.

    For model ``m`` in ``range(n_models)`` the predictions are the ``n_classes`` columns of ``x``
    starting at ``m * features_per_model + slice_start`` (by default the 12:18 window of each
    30-column group, the same window the "train on middle slice only" cell uses — presumably
    the centre slice's per-class probabilities; TODO confirm against the cache layout).

    Reads the notebook globals ``n_models``, ``loss_fn`` and ``class_weights``.

    Returns:
        (per_model_preds, mean_preds): list of per-model prediction tensors and their mean.
    """
    assert x.shape[1] >= n_models * features_per_model, \
        'x has fewer columns than n_models * features_per_model'
    print(header)
    per_model_preds = []
    for model_id in range(n_models):
        start = model_id * features_per_model + slice_start
        model_preds = x[:, start:start + n_classes]
        per_model_preds.append(model_preds)
        # A 0-dim tensor formats as a plain float inside an f-string.
        print(f'model {model_id}: {loss_fn(model_preds, y, weight=class_weights)}')

    mean_preds = torch.mean(torch.stack(per_model_preds), dim=0)
    print(f'averaged ensemble: {loss_fn(mean_preds, y, weight=class_weights)}')
    return per_model_preds, mean_preds


report_l1_losses(train_x, train_y, 'train')
# Keep the validation results bound to the same names this cell has always left behind.
preds, mean_preds = report_l1_losses(val_x, val_y, '\nval')

train
model 0: 0.06113557517528534
model 1: 0.062465518712997437
model 2: 0.0627981498837471
model 3: 0.061474189162254333
model 4: 0.06128436699509621
averaged ensemble: 0.05732988193631172

val
model 0: 0.061941273510456085
model 1: 0.06328877061605453
model 2: 0.06281528621912003
model 3: 0.06388995796442032
model 4: 0.06146755814552307
averaged ensemble: 0.05845646187663078


In [5]:
# train on middle slice only

# train_x_1slice = []
# for model_id in range(n_models):
#     train_x_1slice.append(train_x[:, model_id*30+12:model_id*30+18])
# train_x_1slice = torch.cat(train_x_1slice, dim=1)

# val_x_1slice = []
# for model_id in range(n_models):
#     val_x_1slice.append(val_x[:, model_id*30+12:model_id*30+18])
# val_x_1slice = torch.cat(val_x_1slice, dim=1)

# train_x = train_x_1slice
# val_x = val_x_1slice

In [11]:
hidden = 128  # width for an MLP alternative; only referenced from commented-out code
features_out = 6

# Earlier experiment: an unconstrained torch.nn.Linear(train_x.shape[1], features_out) blender.


class Model(torch.nn.Module):
    """Per-output convex blend of the input feature columns.

    Each output column is a weighted sum of all input columns, with weights
    ``|w1.weight| / sum(|w1.weight|)`` taken row by row, so they are non-negative and sum to 1.
    The result is clamped to [0, 1] so it can go straight into ``F.binary_cross_entropy``
    (for inputs already in [0, 1] the blend never leaves that range).

    Args:
        in_features: number of input columns; defaults to ``train_x.shape[1]`` (notebook global)
            so that ``Model()`` keeps working.
        out_features: number of outputs; defaults to the global ``features_out``.
    """

    def __init__(self, in_features=None, out_features=None):
        super().__init__()
        if in_features is None:
            in_features = train_x.shape[1]
        if out_features is None:
            out_features = features_out
        self.w1 = torch.nn.Linear(in_features, out_features, bias=False)

    def normalized_weight(self):
        """Return |w1.weight| normalised so every row sums to 1, shape (out_features, in_features)."""
        abs_weight = torch.abs(self.w1.weight)
        return abs_weight / torch.sum(abs_weight, 1, keepdim=True)

    def forward(self, x):
        # Alternative tried earlier: w / sum(w) without abs(), which lets weights go negative.
        # w1.bias is None (bias=False), so F.linear adds nothing.
        x = F.linear(x, self.normalized_weight(), self.w1.bias)
        return torch.clamp(x, 0, 1)


SEED = 42
N_STEPS = 10000
LR = 1e-4
EVAL_EVERY = 100

torch.manual_seed(SEED)  # reproducible weight initialisation
model = Model()
print(model.w1.weight.shape)
# model = torch.nn.Sequential(
#         torch.nn.Linear(train_x.shape[1], hidden),
#         torch.nn.ReLU(),
# #         torch.nn.Dropout(0.2),
#         torch.nn.Linear(hidden, features_out),
# )

# Fall back to CPU so the cell still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_x = train_x.to(device)
train_y = train_y.to(device)
val_x = val_x.to(device)
val_y = val_y.to(device)
class_weights = class_weights.to(device)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Full-batch training: every step uses the whole training set.
# NOTE: with the stage-2 fold setting, val data is also in train_x, so "val" is not held-out.
for i in tqdm(range(N_STEPS)):
    optimizer.zero_grad()

    y_hat = model(train_x)
    loss = F.binary_cross_entropy(y_hat, train_y, weight=class_weights)
    loss.backward()
    optimizer.step()

    if i % EVAL_EVERY == 0:
        model.eval()
        with torch.no_grad():  # monitoring only: don't build an autograd graph
            val_y_hat = model(val_x)
            val_loss = F.binary_cross_entropy(val_y_hat, val_y, weight=class_weights)
        model.train()

        # tqdm.write prints above the progress bar instead of interleaving with it.
        tqdm.write(f'{i:04d}: train: {loss.item():.04f}, val: {val_loss.item():.04f}')


  0%|          | 0/10000 [00:00<?, ?it/s][A
  0%|          | 22/10000 [00:00<00:45, 218.01it/s][A

torch.Size([6, 150])
0000: train: 0.1309, val: 0.1316



  0%|          | 48/10000 [00:00<00:43, 228.48it/s][A
  1%|          | 73/10000 [00:00<00:42, 234.03it/s][A
  1%|          | 100/10000 [00:00<00:40, 242.22it/s][A
  1%|▏         | 127/10000 [00:00<00:39, 249.45it/s][A
  2%|▏         | 155/10000 [00:00<00:38, 256.62it/s][A

0100: train: 0.1205, val: 0.1211



  2%|▏         | 183/10000 [00:00<00:37, 262.33it/s][A
  2%|▏         | 211/10000 [00:00<00:36, 266.95it/s][A
  2%|▏         | 238/10000 [00:00<00:36, 266.88it/s][A

0200: train: 0.1092, val: 0.1097



  3%|▎         | 266/10000 [00:01<00:36, 269.31it/s][A
  3%|▎         | 295/10000 [00:01<00:35, 273.66it/s][A
  3%|▎         | 323/10000 [00:01<00:35, 275.25it/s][A
  4%|▎         | 351/10000 [00:01<00:34, 276.16it/s][A

0300: train: 0.0953, val: 0.0957



  4%|▍         | 379/10000 [00:01<00:34, 275.18it/s][A
  4%|▍         | 407/10000 [00:01<00:34, 275.07it/s][A
  4%|▍         | 435/10000 [00:01<00:34, 275.24it/s][A

0400: train: 0.0803, val: 0.0807



  5%|▍         | 463/10000 [00:01<00:34, 275.93it/s][A
  5%|▍         | 491/10000 [00:01<00:34, 275.62it/s][A
  5%|▌         | 519/10000 [00:01<00:34, 274.04it/s][A
  5%|▌         | 547/10000 [00:02<00:34, 274.14it/s][A

0500: train: 0.0689, val: 0.0694



  6%|▌         | 575/10000 [00:02<00:34, 273.91it/s][A
  6%|▌         | 603/10000 [00:02<00:34, 270.07it/s][A
  6%|▋         | 631/10000 [00:02<00:34, 269.59it/s][A

0600: train: 0.0639, val: 0.0647



  7%|▋         | 659/10000 [00:02<00:34, 272.24it/s][A
  7%|▋         | 687/10000 [00:02<00:34, 272.64it/s][A
  7%|▋         | 715/10000 [00:02<00:33, 273.40it/s][A
  7%|▋         | 743/10000 [00:02<00:33, 274.52it/s][A

0700: train: 0.0627, val: 0.0636



  8%|▊         | 771/10000 [00:02<00:33, 274.77it/s][A
  8%|▊         | 799/10000 [00:02<00:33, 273.01it/s][A
  8%|▊         | 827/10000 [00:03<00:33, 272.46it/s][A
  9%|▊         | 855/10000 [00:03<00:33, 273.12it/s][A

0800: train: 0.0622, val: 0.0631



  9%|▉         | 883/10000 [00:03<00:33, 271.22it/s][A
  9%|▉         | 911/10000 [00:03<00:33, 270.71it/s][A
  9%|▉         | 939/10000 [00:03<00:33, 271.46it/s][A

0900: train: 0.0618, val: 0.0626



 10%|▉         | 967/10000 [00:03<00:33, 273.64it/s][A
 10%|▉         | 995/10000 [00:03<00:32, 274.97it/s][A
 10%|█         | 1023/10000 [00:03<00:32, 274.39it/s][A
 11%|█         | 1051/10000 [00:03<00:32, 275.44it/s][A

1000: train: 0.0614, val: 0.0623



 11%|█         | 1079/10000 [00:03<00:32, 274.83it/s][A
 11%|█         | 1107/10000 [00:04<00:32, 275.62it/s][A
 11%|█▏        | 1135/10000 [00:04<00:32, 275.80it/s][A

1100: train: 0.0610, val: 0.0619



 12%|█▏        | 1163/10000 [00:04<00:32, 275.90it/s][A
 12%|█▏        | 1192/10000 [00:04<00:31, 278.09it/s][A
 12%|█▏        | 1220/10000 [00:04<00:31, 276.75it/s][A
 12%|█▏        | 1248/10000 [00:04<00:31, 275.55it/s][A

1200: train: 0.0606, val: 0.0615



 13%|█▎        | 1276/10000 [00:04<00:31, 275.07it/s][A
 13%|█▎        | 1304/10000 [00:04<00:31, 275.40it/s][A
 13%|█▎        | 1332/10000 [00:04<00:31, 274.95it/s][A

1300: train: 0.0603, val: 0.0611



 14%|█▎        | 1360/10000 [00:04<00:31, 276.25it/s][A
 14%|█▍        | 1388/10000 [00:05<00:31, 274.88it/s][A
 14%|█▍        | 1416/10000 [00:05<00:31, 273.05it/s][A
 14%|█▍        | 1444/10000 [00:05<00:31, 271.91it/s][A

1400: train: 0.0599, val: 0.0608



 15%|█▍        | 1472/10000 [00:05<00:31, 274.03it/s][A
 15%|█▌        | 1500/10000 [00:05<00:30, 274.89it/s][A
 15%|█▌        | 1528/10000 [00:05<00:30, 274.82it/s][A
 16%|█▌        | 1556/10000 [00:05<00:30, 275.79it/s][A

1500: train: 0.0595, val: 0.0604



 16%|█▌        | 1584/10000 [00:05<00:30, 272.92it/s][A
 16%|█▌        | 1612/10000 [00:05<00:31, 270.11it/s][A
 16%|█▋        | 1640/10000 [00:06<00:31, 268.32it/s][A

1600: train: 0.0592, val: 0.0601



 17%|█▋        | 1667/10000 [00:06<00:31, 266.65it/s][A
 17%|█▋        | 1694/10000 [00:06<00:31, 267.12it/s][A
 17%|█▋        | 1721/10000 [00:06<00:31, 265.28it/s][A
 17%|█▋        | 1749/10000 [00:06<00:30, 267.56it/s][A

1700: train: 0.0590, val: 0.0598



 18%|█▊        | 1777/10000 [00:06<00:30, 269.11it/s][A
 18%|█▊        | 1805/10000 [00:06<00:30, 269.97it/s][A
 18%|█▊        | 1833/10000 [00:06<00:30, 271.90it/s][A

1800: train: 0.0588, val: 0.0596



 19%|█▊        | 1861/10000 [00:06<00:29, 273.02it/s][A
 19%|█▉        | 1889/10000 [00:06<00:29, 272.92it/s][A
 19%|█▉        | 1917/10000 [00:07<00:29, 274.20it/s][A
 19%|█▉        | 1945/10000 [00:07<00:29, 273.63it/s][A

1900: train: 0.0586, val: 0.0594



 20%|█▉        | 1973/10000 [00:07<00:29, 272.47it/s][A
 20%|██        | 2001/10000 [00:07<00:29, 271.02it/s][A
 20%|██        | 2029/10000 [00:07<00:29, 271.87it/s][A

2000: train: 0.0584, val: 0.0592



 21%|██        | 2057/10000 [00:07<00:29, 271.90it/s][A
 21%|██        | 2085/10000 [00:07<00:29, 272.63it/s][A
 21%|██        | 2113/10000 [00:07<00:28, 272.96it/s][A
 21%|██▏       | 2141/10000 [00:07<00:28, 274.01it/s][A

2100: train: 0.0582, val: 0.0591



 22%|██▏       | 2169/10000 [00:07<00:28, 274.50it/s][A
 22%|██▏       | 2197/10000 [00:08<00:28, 274.60it/s][A
 22%|██▏       | 2225/10000 [00:08<00:28, 273.66it/s][A
 23%|██▎       | 2253/10000 [00:08<00:28, 274.37it/s][A

2200: train: 0.0581, val: 0.0589



 23%|██▎       | 2281/10000 [00:08<00:28, 275.15it/s][A
 23%|██▎       | 2309/10000 [00:08<00:28, 274.33it/s][A
 23%|██▎       | 2337/10000 [00:08<00:27, 275.59it/s][A

2300: train: 0.0580, val: 0.0588



 24%|██▎       | 2365/10000 [00:08<00:27, 274.25it/s][A
 24%|██▍       | 2393/10000 [00:08<00:27, 275.05it/s][A
 24%|██▍       | 2421/10000 [00:08<00:27, 273.16it/s][A
 24%|██▍       | 2449/10000 [00:08<00:27, 272.28it/s][A

2400: train: 0.0579, val: 0.0587



 25%|██▍       | 2477/10000 [00:09<00:27, 273.00it/s][A
 25%|██▌       | 2505/10000 [00:09<00:27, 272.77it/s][A
 25%|██▌       | 2533/10000 [00:09<00:27, 273.45it/s][A

2500: train: 0.0577, val: 0.0586



 26%|██▌       | 2561/10000 [00:09<00:27, 273.72it/s][A
 26%|██▌       | 2589/10000 [00:09<00:27, 274.07it/s][A
 26%|██▌       | 2617/10000 [00:09<00:26, 273.78it/s][A
 26%|██▋       | 2645/10000 [00:09<00:26, 273.46it/s][A

2600: train: 0.0577, val: 0.0585



 27%|██▋       | 2673/10000 [00:09<00:26, 272.63it/s][A
 27%|██▋       | 2701/10000 [00:09<00:27, 269.75it/s][A
 27%|██▋       | 2730/10000 [00:10<00:26, 273.52it/s][A

2700: train: 0.0576, val: 0.0584



 28%|██▊       | 2758/10000 [00:10<00:26, 273.73it/s][A
 28%|██▊       | 2787/10000 [00:10<00:26, 276.39it/s][A
 28%|██▊       | 2815/10000 [00:10<00:26, 275.28it/s][A
 28%|██▊       | 2843/10000 [00:10<00:25, 275.54it/s][A

2800: train: 0.0575, val: 0.0584



 29%|██▊       | 2871/10000 [00:10<00:25, 275.40it/s][A
 29%|██▉       | 2899/10000 [00:10<00:25, 275.31it/s][A
 29%|██▉       | 2927/10000 [00:10<00:25, 275.73it/s][A
 30%|██▉       | 2955/10000 [00:10<00:25, 274.94it/s][A

2900: train: 0.0575, val: 0.0583



 30%|██▉       | 2983/10000 [00:10<00:25, 275.69it/s][A
 30%|███       | 3011/10000 [00:11<00:25, 274.69it/s][A
 30%|███       | 3040/10000 [00:11<00:25, 276.64it/s][A

3000: train: 0.0575, val: 0.0583



 31%|███       | 3068/10000 [00:11<00:25, 276.48it/s][A
 31%|███       | 3096/10000 [00:11<00:24, 277.15it/s][A
 31%|███       | 3124/10000 [00:11<00:25, 274.72it/s][A
 32%|███▏      | 3152/10000 [00:11<00:25, 272.47it/s][A

3100: train: 0.0575, val: 0.0583



 32%|███▏      | 3180/10000 [00:11<00:25, 270.90it/s][A
 32%|███▏      | 3208/10000 [00:11<00:25, 267.20it/s][A
 32%|███▏      | 3235/10000 [00:11<00:25, 264.90it/s][A

3200: train: 0.0574, val: 0.0583



 33%|███▎      | 3263/10000 [00:11<00:25, 267.13it/s][A
 33%|███▎      | 3291/10000 [00:12<00:24, 268.74it/s][A
 33%|███▎      | 3319/10000 [00:12<00:24, 270.42it/s][A
 33%|███▎      | 3347/10000 [00:12<00:24, 271.35it/s][A

3300: train: 0.0574, val: 0.0582



 34%|███▍      | 3375/10000 [00:12<00:24, 272.58it/s][A
 34%|███▍      | 3403/10000 [00:12<00:24, 272.49it/s][A
 34%|███▍      | 3431/10000 [00:12<00:24, 273.55it/s][A

3400: train: 0.0574, val: 0.0582



 35%|███▍      | 3459/10000 [00:12<00:23, 273.56it/s][A
 35%|███▍      | 3487/10000 [00:12<00:23, 273.97it/s][A
 35%|███▌      | 3515/10000 [00:12<00:23, 272.49it/s][A
 35%|███▌      | 3543/10000 [00:13<00:23, 271.98it/s][A

3500: train: 0.0574, val: 0.0582



 36%|███▌      | 3571/10000 [00:13<00:23, 270.13it/s][A
 36%|███▌      | 3599/10000 [00:13<00:23, 267.31it/s][A
 36%|███▋      | 3626/10000 [00:13<00:23, 267.44it/s][A
 37%|███▋      | 3654/10000 [00:13<00:23, 269.89it/s][A

3600: train: 0.0574, val: 0.0582



 37%|███▋      | 3682/10000 [00:13<00:23, 271.55it/s][A
 37%|███▋      | 3710/10000 [00:13<00:23, 269.68it/s][A
 37%|███▋      | 3739/10000 [00:13<00:22, 273.70it/s][A

3700: train: 0.0574, val: 0.0582



 38%|███▊      | 3767/10000 [00:13<00:22, 273.10it/s][A
 38%|███▊      | 3795/10000 [00:13<00:22, 271.47it/s][A
 38%|███▊      | 3823/10000 [00:14<00:22, 271.68it/s][A
 39%|███▊      | 3851/10000 [00:14<00:22, 272.12it/s][A

3800: train: 0.0574, val: 0.0582



 39%|███▉      | 3879/10000 [00:14<00:22, 271.37it/s][A
 39%|███▉      | 3907/10000 [00:14<00:22, 271.37it/s][A
 39%|███▉      | 3935/10000 [00:14<00:22, 271.54it/s][A

3900: train: 0.0574, val: 0.0582



 40%|███▉      | 3963/10000 [00:14<00:22, 273.10it/s][A
 40%|███▉      | 3991/10000 [00:14<00:21, 273.77it/s][A
 40%|████      | 4019/10000 [00:14<00:21, 271.99it/s][A
 40%|████      | 4047/10000 [00:14<00:22, 270.14it/s][A

4000: train: 0.0574, val: 0.0582



 41%|████      | 4075/10000 [00:14<00:21, 271.19it/s][A
 41%|████      | 4103/10000 [00:15<00:21, 271.61it/s][A
 41%|████▏     | 4131/10000 [00:15<00:21, 270.68it/s][A

4100: train: 0.0574, val: 0.0582



 42%|████▏     | 4159/10000 [00:15<00:21, 271.78it/s][A
 42%|████▏     | 4187/10000 [00:15<00:21, 271.41it/s][A
 42%|████▏     | 4215/10000 [00:15<00:21, 267.93it/s][A
 42%|████▏     | 4243/10000 [00:15<00:21, 269.33it/s][A

4200: train: 0.0574, val: 0.0582



 43%|████▎     | 4271/10000 [00:15<00:21, 270.34it/s][A
 43%|████▎     | 4299/10000 [00:15<00:21, 269.94it/s][A
 43%|████▎     | 4326/10000 [00:15<00:21, 266.23it/s][A
 44%|████▎     | 4354/10000 [00:15<00:20, 270.15it/s][A

4300: train: 0.0574, val: 0.0582



 44%|████▍     | 4382/10000 [00:16<00:20, 271.65it/s][A
 44%|████▍     | 4410/10000 [00:16<00:20, 272.10it/s][A
 44%|████▍     | 4438/10000 [00:16<00:20, 270.86it/s][A

4400: train: 0.0574, val: 0.0582



 45%|████▍     | 4466/10000 [00:16<00:20, 269.22it/s][A
 45%|████▍     | 4493/10000 [00:16<00:20, 269.05it/s][A
 45%|████▌     | 4520/10000 [00:16<00:20, 268.69it/s][A
 45%|████▌     | 4548/10000 [00:16<00:20, 270.17it/s][A

4500: train: 0.0574, val: 0.0582



 46%|████▌     | 4576/10000 [00:16<00:20, 270.98it/s][A
 46%|████▌     | 4604/10000 [00:16<00:19, 271.07it/s][A
 46%|████▋     | 4632/10000 [00:17<00:19, 271.38it/s][A

4600: train: 0.0574, val: 0.0582



 47%|████▋     | 4660/10000 [00:17<00:19, 271.08it/s][A
 47%|████▋     | 4688/10000 [00:17<00:19, 271.58it/s][A
 47%|████▋     | 4716/10000 [00:17<00:19, 267.89it/s][A
 47%|████▋     | 4744/10000 [00:17<00:19, 269.35it/s][A

4700: train: 0.0574, val: 0.0582



 48%|████▊     | 4772/10000 [00:17<00:19, 271.38it/s][A
 48%|████▊     | 4800/10000 [00:17<00:19, 270.83it/s][A
 48%|████▊     | 4828/10000 [00:17<00:19, 270.63it/s][A

4800: train: 0.0574, val: 0.0582



 49%|████▊     | 4856/10000 [00:17<00:19, 266.32it/s][A
 49%|████▉     | 4883/10000 [00:17<00:19, 265.74it/s][A
 49%|████▉     | 4910/10000 [00:18<00:19, 266.02it/s][A
 49%|████▉     | 4938/10000 [00:18<00:18, 267.70it/s][A

4900: train: 0.0574, val: 0.0582



 50%|████▉     | 4965/10000 [00:18<00:18, 267.89it/s][A
 50%|████▉     | 4992/10000 [00:18<00:18, 268.40it/s][A
 50%|█████     | 5020/10000 [00:18<00:18, 269.56it/s][A
 50%|█████     | 5047/10000 [00:18<00:18, 268.89it/s][A

5000: train: 0.0574, val: 0.0582



 51%|█████     | 5074/10000 [00:18<00:18, 268.62it/s][A
 51%|█████     | 5101/10000 [00:18<00:18, 266.69it/s][A
 51%|█████▏    | 5129/10000 [00:18<00:18, 267.97it/s][A

5100: train: 0.0574, val: 0.0582



 52%|█████▏    | 5156/10000 [00:18<00:18, 267.21it/s][A
 52%|█████▏    | 5184/10000 [00:19<00:17, 268.97it/s][A
 52%|█████▏    | 5212/10000 [00:19<00:17, 269.64it/s][A
 52%|█████▏    | 5239/10000 [00:19<00:17, 268.63it/s][A

5200: train: 0.0574, val: 0.0582



 53%|█████▎    | 5266/10000 [00:19<00:17, 267.78it/s][A
 53%|█████▎    | 5294/10000 [00:19<00:17, 268.76it/s][A
 53%|█████▎    | 5321/10000 [00:19<00:17, 267.59it/s][A
 53%|█████▎    | 5349/10000 [00:19<00:17, 268.62it/s][A

5300: train: 0.0574, val: 0.0582



 54%|█████▍    | 5376/10000 [00:19<00:17, 268.14it/s][A
 54%|█████▍    | 5403/10000 [00:19<00:17, 266.37it/s][A
 54%|█████▍    | 5431/10000 [00:20<00:17, 268.05it/s][A

5400: train: 0.0574, val: 0.0582



 55%|█████▍    | 5459/10000 [00:20<00:16, 269.26it/s][A
 55%|█████▍    | 5487/10000 [00:20<00:16, 271.09it/s][A
 55%|█████▌    | 5515/10000 [00:20<00:16, 269.39it/s][A
 55%|█████▌    | 5542/10000 [00:20<00:16, 269.10it/s][A

5500: train: 0.0574, val: 0.0582



 56%|█████▌    | 5570/10000 [00:20<00:16, 269.78it/s][A
 56%|█████▌    | 5598/10000 [00:20<00:16, 271.78it/s][A
 56%|█████▋    | 5626/10000 [00:20<00:16, 272.72it/s][A
 57%|█████▋    | 5654/10000 [00:20<00:15, 273.69it/s][A

5600: train: 0.0574, val: 0.0582



 57%|█████▋    | 5682/10000 [00:20<00:15, 272.23it/s][A
 57%|█████▋    | 5710/10000 [00:21<00:15, 272.36it/s][A
 57%|█████▋    | 5738/10000 [00:21<00:15, 270.93it/s][A

5700: train: 0.0574, val: 0.0582



 58%|█████▊    | 5766/10000 [00:21<00:15, 269.55it/s][A
 58%|█████▊    | 5794/10000 [00:21<00:15, 271.17it/s][A
 58%|█████▊    | 5822/10000 [00:21<00:15, 273.71it/s][A
 58%|█████▊    | 5850/10000 [00:21<00:15, 272.01it/s][A

5800: train: 0.0574, val: 0.0582



 59%|█████▉    | 5878/10000 [00:21<00:15, 272.88it/s][A
 59%|█████▉    | 5906/10000 [00:21<00:15, 270.55it/s][A
 59%|█████▉    | 5934/10000 [00:21<00:15, 269.02it/s][A

5900: train: 0.0574, val: 0.0582



 60%|█████▉    | 5961/10000 [00:21<00:15, 267.12it/s][A
 60%|█████▉    | 5988/10000 [00:22<00:15, 267.30it/s][A
 60%|██████    | 6016/10000 [00:22<00:14, 269.14it/s][A
 60%|██████    | 6043/10000 [00:22<00:14, 268.63it/s][A

6000: train: 0.0574, val: 0.0582



 61%|██████    | 6070/10000 [00:22<00:14, 269.03it/s][A
 61%|██████    | 6097/10000 [00:22<00:14, 268.93it/s][A
 61%|██████▏   | 6125/10000 [00:22<00:14, 269.67it/s][A
 62%|██████▏   | 6152/10000 [00:22<00:14, 267.80it/s][A

6100: train: 0.0574, val: 0.0582



 62%|██████▏   | 6179/10000 [00:22<00:14, 267.82it/s][A
 62%|██████▏   | 6206/10000 [00:22<00:14, 266.17it/s][A
 62%|██████▏   | 6233/10000 [00:22<00:14, 265.73it/s][A

6200: train: 0.0574, val: 0.0582



 63%|██████▎   | 6260/10000 [00:23<00:14, 265.59it/s][A
 63%|██████▎   | 6288/10000 [00:23<00:13, 269.22it/s][A
 63%|██████▎   | 6316/10000 [00:23<00:13, 270.41it/s][A
 63%|██████▎   | 6344/10000 [00:23<00:13, 268.98it/s][A

6300: train: 0.0574, val: 0.0582



 64%|██████▎   | 6371/10000 [00:23<00:13, 269.17it/s][A
 64%|██████▍   | 6398/10000 [00:23<00:13, 267.58it/s][A
 64%|██████▍   | 6425/10000 [00:23<00:13, 267.69it/s][A
 65%|██████▍   | 6452/10000 [00:23<00:13, 264.80it/s][A

6400: train: 0.0574, val: 0.0582



 65%|██████▍   | 6479/10000 [00:23<00:13, 262.34it/s][A
 65%|██████▌   | 6507/10000 [00:24<00:13, 265.36it/s][A
 65%|██████▌   | 6535/10000 [00:24<00:12, 267.09it/s][A

6500: train: 0.0574, val: 0.0582



 66%|██████▌   | 6562/10000 [00:24<00:12, 267.66it/s][A
 66%|██████▌   | 6590/10000 [00:24<00:12, 269.40it/s][A
 66%|██████▌   | 6618/10000 [00:24<00:12, 270.12it/s][A
 66%|██████▋   | 6646/10000 [00:24<00:12, 268.48it/s][A

6600: train: 0.0574, val: 0.0582



 67%|██████▋   | 6673/10000 [00:24<00:12, 266.30it/s][A
 67%|██████▋   | 6700/10000 [00:24<00:12, 266.59it/s][A
 67%|██████▋   | 6728/10000 [00:24<00:12, 269.23it/s][A
 68%|██████▊   | 6756/10000 [00:24<00:11, 272.23it/s][A

6700: train: 0.0574, val: 0.0582



 68%|██████▊   | 6784/10000 [00:25<00:11, 271.09it/s][A
 68%|██████▊   | 6812/10000 [00:25<00:11, 269.00it/s][A
 68%|██████▊   | 6839/10000 [00:25<00:11, 267.16it/s][A

6800: train: 0.0574, val: 0.0582



 69%|██████▊   | 6866/10000 [00:25<00:11, 264.24it/s][A
 69%|██████▉   | 6893/10000 [00:25<00:11, 264.98it/s][A
 69%|██████▉   | 6920/10000 [00:25<00:11, 264.69it/s][A
 69%|██████▉   | 6947/10000 [00:25<00:11, 263.30it/s][A

6900: train: 0.0574, val: 0.0582



 70%|██████▉   | 6974/10000 [00:25<00:11, 264.81it/s][A
 70%|███████   | 7001/10000 [00:25<00:11, 266.33it/s][A
 70%|███████   | 7028/10000 [00:25<00:11, 266.87it/s][A

7000: train: 0.0574, val: 0.0582



 71%|███████   | 7056/10000 [00:26<00:10, 268.02it/s][A
 71%|███████   | 7084/10000 [00:26<00:10, 268.01it/s][A
 71%|███████   | 7111/10000 [00:26<00:10, 266.59it/s][A
 71%|███████▏  | 7138/10000 [00:26<00:10, 265.62it/s][A

7100: train: 0.0574, val: 0.0582



 72%|███████▏  | 7165/10000 [00:26<00:10, 263.73it/s][A
 72%|███████▏  | 7192/10000 [00:26<00:10, 263.25it/s][A
 72%|███████▏  | 7219/10000 [00:26<00:10, 260.93it/s][A
 72%|███████▏  | 7246/10000 [00:26<00:10, 261.70it/s][A

7200: train: 0.0574, val: 0.0582



 73%|███████▎  | 7274/10000 [00:26<00:10, 265.86it/s][A
 73%|███████▎  | 7301/10000 [00:26<00:10, 266.43it/s][A
 73%|███████▎  | 7329/10000 [00:27<00:09, 269.06it/s][A

7300: train: 0.0574, val: 0.0582



 74%|███████▎  | 7356/10000 [00:27<00:09, 268.76it/s][A
 74%|███████▍  | 7383/10000 [00:27<00:09, 265.01it/s][A
 74%|███████▍  | 7410/10000 [00:27<00:09, 262.54it/s][A
 74%|███████▍  | 7437/10000 [00:27<00:09, 261.68it/s][A

7400: train: 0.0574, val: 0.0582



 75%|███████▍  | 7464/10000 [00:27<00:09, 260.40it/s][A
 75%|███████▍  | 7491/10000 [00:27<00:09, 259.82it/s][A
 75%|███████▌  | 7518/10000 [00:27<00:09, 260.55it/s][A
 75%|███████▌  | 7545/10000 [00:27<00:09, 259.31it/s][A

7500: train: 0.0574, val: 0.0582



 76%|███████▌  | 7572/10000 [00:28<00:09, 260.40it/s][A
 76%|███████▌  | 7600/10000 [00:28<00:09, 264.65it/s][A
 76%|███████▋  | 7628/10000 [00:28<00:08, 266.34it/s][A
 77%|███████▋  | 7655/10000 [00:28<00:08, 267.25it/s][A

7600: train: 0.0574, val: 0.0582



 77%|███████▋  | 7683/10000 [00:28<00:08, 269.43it/s][A
 77%|███████▋  | 7710/10000 [00:28<00:08, 267.50it/s][A
 77%|███████▋  | 7737/10000 [00:28<00:08, 265.30it/s][A

7700: train: 0.0574, val: 0.0582



 78%|███████▊  | 7764/10000 [00:28<00:08, 264.79it/s][A
 78%|███████▊  | 7791/10000 [00:28<00:08, 259.12it/s][A
 78%|███████▊  | 7818/10000 [00:28<00:08, 259.77it/s][A
 78%|███████▊  | 7845/10000 [00:29<00:08, 260.28it/s][A

7800: train: 0.0574, val: 0.0582



 79%|███████▊  | 7872/10000 [00:29<00:08, 259.03it/s][A
 79%|███████▉  | 7898/10000 [00:29<00:08, 256.97it/s][A
 79%|███████▉  | 7924/10000 [00:29<00:08, 255.07it/s][A
 80%|███████▉  | 7951/10000 [00:29<00:07, 258.88it/s][A

7900: train: 0.0574, val: 0.0582



 80%|███████▉  | 7979/10000 [00:29<00:07, 263.42it/s][A
 80%|████████  | 8006/10000 [00:29<00:07, 265.20it/s][A
 80%|████████  | 8034/10000 [00:29<00:07, 267.68it/s][A

8000: train: 0.0574, val: 0.0582



 81%|████████  | 8061/10000 [00:29<00:07, 267.48it/s][A
 81%|████████  | 8088/10000 [00:29<00:07, 266.25it/s][A
 81%|████████  | 8115/10000 [00:30<00:07, 265.32it/s][A
 81%|████████▏ | 8142/10000 [00:30<00:07, 264.34it/s][A

8100: train: 0.0573, val: 0.0582



 82%|████████▏ | 8169/10000 [00:30<00:06, 261.58it/s][A
 82%|████████▏ | 8196/10000 [00:30<00:06, 261.12it/s][A
 82%|████████▏ | 8223/10000 [00:30<00:06, 260.85it/s][A
 82%|████████▎ | 8250/10000 [00:30<00:06, 263.36it/s][A

8200: train: 0.0574, val: 0.0582



 83%|████████▎ | 8278/10000 [00:30<00:06, 266.94it/s][A
 83%|████████▎ | 8306/10000 [00:30<00:06, 268.53it/s][A
 83%|████████▎ | 8333/10000 [00:30<00:06, 267.45it/s][A

8300: train: 0.0574, val: 0.0582



 84%|████████▎ | 8360/10000 [00:30<00:06, 268.17it/s][A
 84%|████████▍ | 8387/10000 [00:31<00:06, 266.53it/s][A
 84%|████████▍ | 8414/10000 [00:31<00:06, 262.79it/s][A
 84%|████████▍ | 8442/10000 [00:31<00:05, 265.49it/s][A

8400: train: 0.0573, val: 0.0582



 85%|████████▍ | 8469/10000 [00:31<00:05, 263.88it/s][A
 85%|████████▍ | 8496/10000 [00:31<00:05, 263.33it/s][A
 85%|████████▌ | 8523/10000 [00:31<00:05, 262.19it/s][A
 86%|████████▌ | 8550/10000 [00:31<00:05, 262.79it/s][A

8500: train: 0.0573, val: 0.0582



 86%|████████▌ | 8577/10000 [00:31<00:05, 258.48it/s][A
 86%|████████▌ | 8604/10000 [00:31<00:05, 260.55it/s][A
 86%|████████▋ | 8632/10000 [00:32<00:05, 264.43it/s][A

8600: train: 0.0573, val: 0.0582



 87%|████████▋ | 8660/10000 [00:32<00:05, 266.72it/s][A
 87%|████████▋ | 8687/10000 [00:32<00:04, 266.03it/s][A
 87%|████████▋ | 8714/10000 [00:32<00:04, 266.78it/s][A
 87%|████████▋ | 8741/10000 [00:32<00:04, 265.87it/s][A

8700: train: 0.0573, val: 0.0582



 88%|████████▊ | 8768/10000 [00:32<00:04, 264.28it/s][A
 88%|████████▊ | 8795/10000 [00:32<00:04, 262.21it/s][A
 88%|████████▊ | 8822/10000 [00:32<00:04, 260.49it/s][A
 88%|████████▊ | 8849/10000 [00:32<00:04, 259.02it/s][A

8800: train: 0.0573, val: 0.0582



 89%|████████▉ | 8877/10000 [00:32<00:04, 262.97it/s][A
 89%|████████▉ | 8905/10000 [00:33<00:04, 266.19it/s][A
 89%|████████▉ | 8932/10000 [00:33<00:04, 265.61it/s][A

8900: train: 0.0573, val: 0.0582



 90%|████████▉ | 8959/10000 [00:33<00:03, 265.33it/s][A
 90%|████████▉ | 8986/10000 [00:33<00:03, 264.91it/s][A
 90%|█████████ | 9013/10000 [00:33<00:03, 261.35it/s][A
 90%|█████████ | 9040/10000 [00:33<00:03, 261.82it/s][A

9000: train: 0.0573, val: 0.0582



 91%|█████████ | 9067/10000 [00:33<00:03, 259.80it/s][A
 91%|█████████ | 9094/10000 [00:33<00:03, 261.25it/s][A
 91%|█████████ | 9121/10000 [00:33<00:03, 262.07it/s][A
 91%|█████████▏| 9149/10000 [00:33<00:03, 266.62it/s][A

9100: train: 0.0573, val: 0.0582



 92%|█████████▏| 9177/10000 [00:34<00:03, 268.00it/s][A
 92%|█████████▏| 9204/10000 [00:34<00:02, 267.46it/s][A
 92%|█████████▏| 9231/10000 [00:34<00:02, 266.48it/s][A

9200: train: 0.0573, val: 0.0582



 93%|█████████▎| 9258/10000 [00:34<00:02, 263.67it/s][A
 93%|█████████▎| 9285/10000 [00:34<00:02, 262.45it/s][A
 93%|█████████▎| 9312/10000 [00:34<00:02, 259.29it/s][A
 93%|█████████▎| 9338/10000 [00:34<00:02, 259.18it/s][A

9300: train: 0.0573, val: 0.0582



 94%|█████████▎| 9365/10000 [00:34<00:02, 260.58it/s][A
 94%|█████████▍| 9393/10000 [00:34<00:02, 264.74it/s][A
 94%|█████████▍| 9420/10000 [00:35<00:02, 265.18it/s][A
 94%|█████████▍| 9448/10000 [00:35<00:02, 266.70it/s][A

9400: train: 0.0573, val: 0.0582



 95%|█████████▍| 9475/10000 [00:35<00:01, 263.16it/s][A
 95%|█████████▌| 9502/10000 [00:35<00:01, 262.30it/s][A
 95%|█████████▌| 9529/10000 [00:35<00:01, 261.74it/s][A

9500: train: 0.0573, val: 0.0582



 96%|█████████▌| 9556/10000 [00:35<00:01, 260.15it/s][A
 96%|█████████▌| 9583/10000 [00:35<00:01, 259.91it/s][A
 96%|█████████▌| 9609/10000 [00:35<00:01, 259.50it/s][A
 96%|█████████▋| 9635/10000 [00:35<00:01, 259.17it/s][A

9600: train: 0.0573, val: 0.0582



 97%|█████████▋| 9661/10000 [00:35<00:01, 258.38it/s][A
 97%|█████████▋| 9687/10000 [00:36<00:01, 258.54it/s][A
 97%|█████████▋| 9713/10000 [00:36<00:01, 257.05it/s][A
 97%|█████████▋| 9739/10000 [00:36<00:01, 256.95it/s][A

9700: train: 0.0573, val: 0.0582



 98%|█████████▊| 9765/10000 [00:36<00:00, 257.32it/s][A
 98%|█████████▊| 9791/10000 [00:36<00:00, 257.54it/s][A
 98%|█████████▊| 9817/10000 [00:36<00:00, 257.51it/s][A
 98%|█████████▊| 9843/10000 [00:36<00:00, 257.85it/s][A

9800: train: 0.0573, val: 0.0582



 99%|█████████▊| 9869/10000 [00:36<00:00, 258.11it/s][A
 99%|█████████▉| 9895/10000 [00:36<00:00, 257.79it/s][A
 99%|█████████▉| 9921/10000 [00:36<00:00, 257.27it/s][A
 99%|█████████▉| 9947/10000 [00:37<00:00, 257.21it/s][A

9900: train: 0.0573, val: 0.0582



100%|█████████▉| 9973/10000 [00:37<00:00, 256.98it/s][A
100%|█████████▉| 9999/10000 [00:37<00:00, 257.56it/s][A
100%|██████████| 10000/10000 [00:37<00:00, 268.34it/s][A

In [9]:
# Learned blend weights: |w1.weight| normalised per row, one row per output class.
# Detached and moved to CPU so the display carries no autograd/device noise.
# NOTE(review): this cell's count (In [9]) is lower than the training cell's (In [11]), so the
# saved output below is likely from an earlier model — re-run after training.
(torch.abs(model.w1.weight) / torch.sum(torch.abs(model.w1.weight), 1, keepdim=True)).detach().cpu()

tensor([[1.6175e-05, 3.3427e-06, 7.2556e-06, 3.6481e-06, 3.5664e-06, 5.8942e-06,
         1.0334e-02, 7.9139e-06, 2.6752e-06, 4.0633e-06, 6.9975e-06, 5.5327e-06,
         7.7679e-03, 1.2487e-05, 3.5652e-06, 1.8475e-06, 5.6836e-06, 7.4572e-06,
         6.7314e-06, 2.8986e-07, 6.8215e-06, 2.6493e-06, 6.9987e-06, 3.5359e-06,
         1.1278e-05, 7.9612e-06, 5.5704e-06, 6.9655e-07, 9.4248e-06, 8.0391e-06,
         7.0075e-02, 5.2306e-06, 7.8575e-06, 3.7447e-06, 3.6100e-06, 2.4426e-06,
         8.0505e-02, 9.6311e-07, 9.8790e-06, 5.6635e-06, 1.0270e-05, 2.7731e-06,
         1.1803e-01, 1.0128e-05, 3.4024e-06, 7.1866e-06, 6.4559e-06, 5.4386e-06,
         3.5133e-02, 9.3689e-06, 9.1701e-06, 4.9436e-06, 4.0378e-06, 5.7272e-06,
         1.2139e-05, 6.4111e-06, 3.2622e-07, 7.0553e-06, 7.3319e-06, 3.8639e-06,
         2.4847e-02, 2.6180e-06, 3.7442e-06, 3.7545e-06, 3.4401e-06, 7.0550e-06,
         4.1661e-02, 7.4850e-06, 1.6732e-06, 4.0225e-06, 7.1409e-06, 2.6647e-06,
         4.1517e-02, 5.8188e

Apply the trained model to the L1 test predictions:

TODO: apply the trained model to the test-set L1 predictions and call `generate_submission` with `clip_eps=1e-5`.