In [None]:
#from multiprocessing import set_start_method
#set_start_method("spawn")

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2"

In [None]:
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.data.all import *
from fastai2.data.core import *
from fastai2.distributed import *
from fastai2.data.transforms import *
from fastai2.vision.all import *
import gc
from itertools import product
from scipy import signal
import seaborn as sns

In [None]:
warnings.simplefilter('ignore')
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', 1000)
pd.set_option('display.max_rows', 500)

from pylab import rcParams
rcParams['figure.figsize'] = 20, 5
rcParams['figure.dpi'] = 300
rcParams['agg.path.chunksize'] = 10000

In [None]:
SEGMENT_SIZE      = 500_000 
TEST_SEGMENT_SIZE = 100_000

WINDOW_SIZE = 1000
BS = max(1,torch.cuda.device_count()) * 192
HIST_BINS = 128
SPLITS = 5
XTRA_DS = False

FEAT_WINDOW = 1

assert SEGMENT_SIZE % WINDOW_SIZE == 0
assert (SEGMENT_SIZE // WINDOW_SIZE) % SPLITS == 0
SEED = 321
DATA_SUFFIX = '_clean'

p_input = Path('input')

# Read data

In [None]:
# read data
train_dtypes = {'time': np.float32, 'signal': np.float32, 'open_channels': np.int32 }
test_dtypes  = {'time': np.float32, 'signal': np.float32 }
df_train  = pd.read_csv(p_input / f'train.csv', dtype= train_dtypes)
df_test   = pd.read_csv(p_input / f'test.csv',  dtype= test_dtypes)
df_train_drift = pd.read_csv(p_input / f'train{DATA_SUFFIX}.csv', dtype= train_dtypes)
df_test_drift  = pd.read_csv(p_input / f'test{DATA_SUFFIX}.csv',  dtype= test_dtypes)
sub   = pd.read_csv(p_input / 'sample_submission.csv',  dtype={'time': np.float32})
df_train['drift'] = df_train['signal'] - df_train_drift['signal']
df_test['drift']  = df_test['signal']  - df_test_drift['signal']

In [None]:
df_train['signal'] =  df_train_drift['signal']
df_test['signal']  =   df_test_drift['signal']

In [None]:
#df_train['open_channels'][2300000:2400000][(df_train['open_channels'][2300000:2400000]==0)]=1

In [None]:
d_xtra_csvs = {
    1: ['outfinaltest10.csv',  'outfinaltest44.csv',],#  'outfinaltest78.csv',],  'outfinaltest10.csv',  'outfinaltest44.csv'],
    3: ['outfinaltest1.csv',   'outfinaltest2.csv',   'outfinaltest3.csv',   'outfinaltest4.csv', 'outfinaltest5.csv'],
    5: ['outfinaltest328.csv', 'outfinaltest534.csv', 'outfinaltest747.csv',]#, 'outfinaltest328.csv', 'outfinaltest534.csv']
}

# Read every extra CSV exactly once and concatenate in a single call.
# (Concatenating inside the loop re-copies all rows read so far on every iteration
# and depended on pd.concat silently dropping the initial `None`.)
xtra_frames = [
    pd.read_csv(p_input / xtra_csv , header=None,names=['time', 'signal', 'open_channels'])
    for xtra_csvs in d_xtra_csvs.values()
    for xtra_csv in xtra_csvs
]
# The original loop prepended each new frame, so the last file read ends up first;
# reversing the list keeps that row order unchanged.
df_train_xtra = pd.concat(xtra_frames[::-1], axis=0)
# The extra recordings get a constant zero in the 'drift' column.
df_train_xtra['drift']  = 0.
if XTRA_DS: df_train = pd.concat((df_train,df_train_xtra), axis=0)

In [None]:
train = torch.cat((torch.FloatTensor(df_train['signal'        ].values).unsqueeze(0),
                   #torch.FloatTensor(df_train['drift'         ].values).unsqueeze(0),
                   torch.FloatTensor(df_train['open_channels' ].values).unsqueeze(0)))
test  = torch.cat((torch.FloatTensor(df_test ['signal'        ].values).unsqueeze(0),
                   #torch.FloatTensor(df_test ['drift'         ].values).unsqueeze(0)
                  ))

In [None]:
test_clean  = np.load(str(p_input / 'test_x_without50hz.npy'))
train_clean = np.load(str(p_input / 'train_x_without50hz.npy'))

In [None]:
plt.plot(test_clean.flatten())

In [None]:
plt.plot(train_clean.flatten())

In [None]:
plt.plot(train[0].flatten())

In [None]:
train[0,:] = Tensor(train_clean)
test[0,:]  = Tensor(test_clean)

# Synth

In [None]:
p_synth = Path('synth')
use_memmap = True
load_fn = np.load if not use_memmap else partial(np.lib.format.open_memmap, mode='r')

try:
    # Fast path: reuse the cached .npy arrays (opened through `load_fn`, i.e. memory-mapped
    # read-only when `use_memmap` is set).
    high = load_fn(str(p_synth / 'high.npy'))
    low  = load_fn(str(p_synth / 'low.npy'))
except Exception as e:
    # Fall back to rebuilding the cache from CSV on any load failure. `Exception` (not a bare
    # `except:`) so that KeyboardInterrupt / SystemExit still stop the notebook.
    print(f"Could not load cached synth arrays ({e!r}); rebuilding from CSV")
    high   = pd.read_csv(p_synth / 'high.csv',header=None).values.astype('uint8')
    low    = pd.read_csv(p_synth /  'low.csv',header=None).values.astype('uint8')
    np.save(str(p_synth / 'high.npy'), high)
    np.save(str(p_synth /  'low.npy'),  low)
high = high.reshape(-1,SEGMENT_SIZE)
low  =  low.reshape(-1,SEGMENT_SIZE)

In [None]:
def get_synth_segment_y(states,max_channels,size=SEGMENT_SIZE):
    """Draw a synthetic open-channel count sequence of length `size`.

    `states` selects the source table ('l' -> `low`, 'h' -> `high`). `max_channels` distinct
    rows are picked at random without replacement, each is read from a random start offset,
    and the rows are summed element-wise. The sum is clamped to the range 0..10.
    """
    pool = {'l':low, 'h':high }[states]
    total = torch.zeros((size,),dtype=torch.uint8)
    rows = np.random.choice(pool.shape[0],max_channels,replace=False)
    # Largest admissible start offset so that a `size`-long slice fits in a row.
    max_start = max(0,SEGMENT_SIZE-size)
    for row in rows:
        start = np.random.randint(1+max_start)
        total += pool[row,start:start+size]
    return total.clamp(0, 10)
y=get_synth_segment_y('h',10,SEGMENT_SIZE)
np.bincount(y,minlength=11)
# (low,1), (low,1), (high,1), (high,3), (high,10), (high,5), (high,1), (high,3), (high,5), (high,10)

In [None]:
for i in range(10): print(np.bincount(get_synth_segment_y('l',5,SEGMENT_SIZE//5),minlength=11))

In [None]:
# -0.12838999,49.9232689,1.73201717

test_std_by_type = {('l', 3): [0.23856227099895477,
              0.2386934608221054,
              0.23847833275794983,
              0.23833979666233063,
              0.24214892089366913,
              0.24653472006320953,
              0.2441188395023346],
             ('h', 3): [0.2698204517364502, 0.2679280936717987],
             ('h', 5): [0.27920418977737427, 0.27736446261405945],
             ('l', 4): [0.23756568133831024,
              0.24620571732521057,
              0.24222536385059357],
             ('h', 1): [0.2443782091140747],
             ('h', 10): [0.3530166447162628, 0.3537397086620331],
             ('l', 2): [0.23785734176635742,
              0.2443142682313919,
              0.24280819296836853]}

train_std_by_type = {('l', 1): [0.24515989422798157, 0.24703997373580933],
             ('h', 1): [0.24486009776592255, 0.2447292059659958],
             ('h', 3): [0.265836238861084],
             ('h', 10): [0.4045635759830475, 0.40377894043922424],
             ('h', 5): [0.28642651438713074, 0.28378984332084656]}

def get_synth_segment_xy(states,max_channels,size=SEGMENT_SIZE,add_ac=False,y=None,add_noise=True,add_bias=True):
    """Synthesise a (signal, open_channels) pair.

    The open-channel sequence `y` (generated with `get_synth_segment_y` when not supplied) is
    convolved with a fixed 18-tap kernel and offset by a constant bias to obtain the signal `x`.

    Parameters
    ----------
    states : key of the state table ('l' or 'h'); also part of the noise-std lookup key.
    max_channels : number of channels to sum; if None it is taken from `y.max()`.
    size : length of the generated sequence when `y` is None.
    add_ac : if True (and `add_noise`), add a sine wave with a randomly drawn frequency and phase.
    y : optional pre-computed open-channel sequence.
    add_noise : add Gaussian noise whose std is drawn from the per-type std tables.
    add_bias : if True (and `add_noise`), add one random constant offset in [-0.4, 0.4).

    Returns
    -------
    (x, y), each with a trailing singleton dimension added.

    Raises
    ------
    KeyError if `add_noise` is set and `(states, min(max_channels, 10))` is in neither
    `train_std_by_type` nor `test_std_by_type`.
    """
    kernel = tensor([-1.6590e-03, -1.1617e-04, -1.0344e-03,  8.4467e-04, -9.7054e-04,
              1.4413e-03,  6.5739e-03,  2.8979e-02,  1.2115e+00, -1.0717e-03,
             -3.8138e-03,  6.0101e-04,  1.2317e-04,  3.1660e-03, -8.8741e-04,
              3.2797e-04,  2.5820e-03, -2.3032e-03])
    dim_k = kernel.numel()
    # Convert to a plain int: a 0-dim tensor hashes by identity, so using it in the
    # (states, max_channels) dictionary key below could never match an entry.
    if max_channels is None: max_channels = int(y.max())
    bias = tensor([-5.5336 if max_channels >= 10 else -2.7708])
    if y is None: y = get_synth_segment_y(states,max_channels,size)
    # Zero-pad so the convolution output has the same length as y.
    y_padded = torch.zeros(y.shape[0]+dim_k-1)
    cc = (dim_k-1)//2
    y_padded[cc:cc+y.shape[0]] = y.float()
    y_padded = y_padded.view(1,1,-1)
    x = (F.conv1d(y_padded,kernel.view(1,1,-1)).flatten() + bias).to(y.device)
    if add_noise:
        if add_ac:
            ac_weight,ac_freq,ac_phase = -0.12838999,50+0.2*(torch.rand(1)-0.5),2*math.pi*torch.rand(1)
            x_range = torch.arange(x.shape[0]).float().to(x.device)
            wave = (ac_weight * torch.sin( x_range * 2*math.pi/10000 * ac_freq + ac_phase))
            x += wave
        # The std tables only go up to 10 channels.
        max_channels = min(max_channels,10)
        std_by_type = train_std_by_type if (states,max_channels) in train_std_by_type else test_std_by_type
        std = np.random.choice(std_by_type[(states,max_channels)])
        x += std * torch.randn(x.numel()).to(x.device)
        if add_bias: x += 0.4*(2*torch.rand(1)-1)
    return x.unsqueeze(-1),y.unsqueeze(-1)


In [None]:
train[-1].view(-1,SEGMENT_SIZE)[0,:]

from fastai2.text.models.qrnn import QRNN
class Synthetizer(nn.Module):
    """Sequence model applied to an open-channel sequence to predict a signal value per step.

    Only `self.QRNN` and `self.fc` are used in `forward`; `self.LSTM` and `self.convs` are
    constructed but their calls are commented out.
    NOTE(review): presumably they are kept so the state-dict keys still match the checkpoint
    loaded right after this class — confirm before removing them.
    """
    def __init__(self):
        super().__init__()
        f_out = 16

        self.QRNN = QRNN(input_size=FEAT_WINDOW, hidden_size=f_out, n_layers=3, batch_first=True, 
                          window=2, bidirectional=True, dropout=0.)
        # Not used in forward (call is commented out below).
        self.LSTM  = nn.LSTM(input_size=f_out*2, hidden_size=f_out, num_layers=1, batch_first=True, 
                            bidirectional=True, dropout=0.)
        # Per-step head: f_out*2 -> f_out -> f_out//2 -> 1.
        self.fc = nn.Sequential(nn.Linear(f_out*2,  f_out),    Swish(), 
                                nn.Linear(f_out,    f_out//2), Swish(),
                                nn.Linear(f_out//2, 1))
        
        # Not used in forward (call is commented out below).
        k,p = (3,1), (1,0)
        self.convs = nn.Sequential(
            ConvLayer( 1,32,k,padding=p), ConvLayer(32,32,k,padding=p), ConvLayer(32,32,k,padding=p),
            ConvLayer(32,32,k,padding=p), ConvLayer(32,32,k,padding=p), ConvLayer(32,f_out*2,k,padding=p))
        
    def forward(self, x):
        """Pad dim 1 of `x` by replication, run the QRNN and the head, then crop the padding.

        assumes x is (batch, seq, FEAT_WINDOW) — TODO confirm against callers.
        """
        n_pad = 48
        # ReplicationPad1d pads the last dim, hence the permute round-trip to pad along dim 1.
        x = nn.ReplicationPad1d(n_pad)(x.permute(0,2,1)).permute(0,2,1)#.unsqueeze(-1)
        x,_ = self.QRNN(x) # bs,16
        #hc = hc[-1].unsqueeze(-1).expand(-1,-1,x.shape[1])
        #x = self.convs(x.permute(0,2,1).unsqueeze(-1)).squeeze(-1)
        #x = torch.cat((hc,x),dim=1).permute(0,2,1)
#        x, _ = self.LSTM(x)
        # Drop the n_pad padded steps at both ends.
        return self.fc(x)[:,n_pad:-n_pad]
synth = Synthetizer()
p='models/y_to_x_qrnn7297_t36000_v9_BS40_SS500000_WS100000_FW1_cv0.089554_clean_clean50hz.pth'
synth.load_state_dict(torch.load(p)['model'])

segment = [4,9]
s = train.view(2,-1,SEGMENT_SIZE)[:,segment,:]
x,y=s[0].flatten(),s[-1].flatten()
print(x.shape)
x_pred_naive = get_synth_segment_xy(None,10,y=y,add_noise=False)[0].flatten()
synth.eval().cuda()
with torch.no_grad(): x_pred_nn = synth(y.view(1,-1,1).cuda()).flatten().cpu()
#plt.plot(x)
plt.plot(x_pred_naive)
plt.plot(x_pred_nn)
((x-x_pred_naive)**2).sum(),((x-x_pred_nn)**2).sum()

In [None]:
test_types = [('l', 3), ('h', 3), ('h', 5), ('l', 4), ('h', 1), ('h', 10), ('h', 5), ('h', 10), ('l', 3), ('h', 3),
              ('l', 3), ('l', 3), ('l', 2), ('l', 4), ('l', 4), ('l',  3), ('l', 3), ('l',  2), ('l', 2), ('l', 3)]

flatten = lambda l: [item for sublist in l for item in sublist]

test_types_x3 = flatten((i,i,i) for i in test_types)
public_types  = test_types_x3[:len(test_types)]
private_types = test_types_x3[len(test_types):]

public_s = np.nan * torch.empty(2,len(public_types),TEST_SEGMENT_SIZE,1)
for i,segment in enumerate(public_types): 
    public_s[0,i],public_s[-1,i]= get_synth_segment_xy(*segment,TEST_SEGMENT_SIZE)
plt.plot(public_s[0].flatten())

In [None]:
private_s = np.nan * torch.empty(2,len(private_types),TEST_SEGMENT_SIZE,1)
for i,segment in enumerate(private_types): 
    private_s[0,i],private_s[-1,i]= get_synth_segment_xy(*segment,TEST_SEGMENT_SIZE)
plt.plot(private_s[0].flatten())

In [None]:
# r = parallel(lambda x:get_synth_segment_xy(*x), private_types)

In [None]:
train_types=flatten([[('l', 1),('l', 2), ('l', 3), ('l', 4)],
             [('h',1)]*10,
             [('h', 3)]*20,
             [('h', 5)]*20,
             [('h', 10)]*60,]) * 10

train_types = [('l', 1),('l', 1),('h', 1),('h', 3),('h',10),('h', 5),('h', 1),('h', 3),('h', 5),('h',10),
               ('h',10),('h',10),('h',10),('h',11),('h',11),('h',12),('h',13),('h',13),('h',14),('h',14),
               ('l', 3),('l', 3),('l', 3),('l', 3),('l', 3),('l', 3),('l', 3),('l', 3),('l', 3),('l', 3),
               ('h', 5),('h', 5),('h', 5),('h', 5),('h', 5),('h', 5),('h', 5),('h', 5),('h', 5),('h', 5),
               ('h', 3),('h', 3),('h', 3),('h', 3),('h', 3),('h', 3),('h', 3),('h', 3),('h', 3),('h', 3)] * 10

train_types = public_types * 500

try:
    # Reuse the cached synthetic training tensor if it exists and has the expected shape.
    train_s = torch.load("train_s")
    expected_shape = (2,len(train_types),SEGMENT_SIZE,1)
    # Explicit check instead of `assert`, which is skipped when Python runs with -O.
    if train_s.shape != expected_shape:
        raise ValueError(f"cached train_s has shape {tuple(train_s.shape)}, expected {expected_shape}")
except Exception as e:
    # `Exception` (not a bare `except:`) so that Ctrl-C during the load is not swallowed.
    print(f"Regenerating train_s ({e!r})")
    # NaN-filled so any slot that is never written is easy to spot.
    train_s = np.nan * torch.empty(2,len(train_types),SEGMENT_SIZE,1)
    for i,segment in progress_bar(enumerate(train_types),total=len(train_types)):
        train_s[0,i],train_s[-1,i] = get_synth_segment_xy(*segment)
    torch.save(train_s,"train_s")

#np.bincount(train_s[2,...].flatten(),minlength=11)

# Filter 

def filter(x):
    """Return `mne.filter.notch_filter` applied to `x` converted to a float64 NumPy array.

    NOTE(review): the positional arguments 10000 and 50.0 are presumably the sampling rate
    and the notch frequency in Hz — confirm against the mne documentation.
    NOTE(review): `mne` is not imported in the visible part of this notebook, and this
    function shadows the built-in `filter`.
    """
    return mne.filter.notch_filter(x.numpy().astype('float64'),10000,50.0)
train[0,:]=Tensor(filter(train[0,:]))
test[0,:] =Tensor(filter( test[0,:]))

In [None]:
train.shape,train_s.shape,test.shape

In [None]:
train   =   train.view(  train.shape[0],-1,SEGMENT_SIZE,1)
test    =    test.view(   test.shape[0],-1,TEST_SEGMENT_SIZE,1)

In [None]:
train = train[:,[0,1,2,3,4,5,6,8,9],...]

# Normalize

In [None]:
# Standardisation statistics, computed over train[0] only.
signal_mean, signal_std = train[0].mean(),train[0].std()
# Histogram range: overall min/max across train[0] and test[0], widened by 0.4 on each side,
# then expressed in standardised units ((value - mean) / std).
# NOTE(review): 0.4 presumably matches the maximum random bias added in get_synth_segment_xy — confirm.
signal_min = (min(train[0].min(), test[0].min())-0.4-signal_mean)/signal_std
signal_max = (max(train[0].max(), test[0].max())+0.4-signal_mean)/signal_std
signal_min, signal_max

# Dataset and splits

In [None]:
# Index of the fold used for validation (0 .. SPLITS-1).
split = 0
# Number of WINDOW_SIZE-long windows per segment that belong to one fold.
split_size = SEGMENT_SIZE//WINDOW_SIZE//SPLITS
windows_per_segment = np.arange(SEGMENT_SIZE//WINDOW_SIZE)
# Window indices (within a segment) of the validation fold: one contiguous run.
valid_split_idx = split*split_size + np.arange(split_size)
all_segments = range(train.shape[1])
# (segment, window) pairs; with XTRA_DS only the first 50 segments contribute validation windows.
valid_idx = list(product(range(50) if XTRA_DS else all_segments,valid_split_idx))
train_idx = list(product(all_segments,windows_per_segment))
# Training windows = all windows minus the validation windows, in sorted order.
train_idx = list(sorted(set(train_idx).difference(set(valid_idx))))

In [None]:
windows_per_segment,split_size

In [None]:
class IonDataset(torch.utils.data.Dataset):
    """Windowed view over segmented recordings.

    `data` is indexed as `data[channel, segment, sample, 0]`: channel 0 is the signal and,
    when `data.shape[0] == 2`, the last channel holds the open-channel labels.
    Each item is one WINDOW_SIZE-long window identified by a `(segment, window)` pair from `idx`
    (all windows of all segments when `idx` is None).

    Parameters
    ----------
    data   : tensor as described above.
    idx    : optional list of `(segment, window)` pairs.
    jitter : if True, randomly move the window start by up to half a window towards a
             neighbouring window that is also present in `idx`.
    shift  : constant offset added to the window; windows running past the segment end wrap around.

    Items are `((signal_window, segment_histogram), labels)` when labels exist, else
    `((signal_window, segment_histogram),)`.
    """
    def __init__(self, data,idx=None,jitter=False,shift=0):
        super().__init__()
        self.data, self.jitter,self.shift = data, jitter, shift
        self.segment_size = data.shape[-2]
        self.idx = ifnone(idx,list(product(range(self.data.shape[1]),np.arange(self.segment_size//WINDOW_SIZE))))
        self.n_inp = 1
        self.has_y = self.data.shape[0] == 2
        self.idx_set = set(self.idx)
        # One max-normalised histogram per segment, computed once up front.
        # NOTE(review): the histogram is taken over the raw signal while signal_min/signal_max
        # are defined in standardised units — confirm this range is intended.
        self.histc = {}
        for s in range(self.data[0].shape[0]):
            x = self.data[0,s].cuda()
            histc = torch.histc(x,bins=HIST_BINS,min=signal_min,max=signal_max)
            histc /= histc.max()
            self.histc[s] = histc.cpu()
            del x
    def __len__(self): return len(self.idx)
    def __getitem__(self, idx):
        s,o=self.idx[idx]
        jitter = 0
        if self.jitter:
            # Only jitter towards neighbouring windows that belong to this dataset.
            lo,hi = 0,0
            if ((s,(o-1)) in self.idx_set): lo = -WINDOW_SIZE//2
            if ((s,(o+1)) in self.idx_set): hi =  WINDOW_SIZE//2
            # torch.randint requires lo < hi; a window with no neighbours keeps jitter = 0
            # instead of raising.
            if hi > lo: jitter = torch.randint(lo,hi,(1,)).item()
        so,se = jitter+o*WINDOW_SIZE,jitter+(o+1)*WINDOW_SIZE
        assert (so < self.segment_size) and (se <= self.segment_size)
        so,se = so + self.shift, se + self.shift
        # Wrap around the segment end when the shifted window runs past it.
        ss = (torch.arange(so,se) % self.segment_size) if (se > self.segment_size) else slice(so,se)
        x =  (self.data[0,s,ss,:], self.histc[s])
        if self.has_y: y_open_channels = self.data[-1,s:s+1,ss,0].long()
        return (x,y_open_channels) if self.has_y else (x,)

#train_ds = IonDataset((train, train_channels_in_segment), train_idx, jitter=False, synth=True)
len(train_ds)
x = torch.empty((len(train_ds),train_ds[0][0].shape[0]))
print(x.shape)
for i in range(len(train_ds)): x[i] = train_ds[i][0].squeeze()
plt.plot(x.flatten())

In [None]:
train_s_ds   = IonDataset(train_s,   jitter=True)
public_s_ds  = IonDataset(public_s)
private_s_ds = IonDataset(private_s)
train_ds     = IonDataset(train)
test_ds      = IonDataset(test)

In [None]:
train_s_ds[0]

In [None]:
class Normalize(Transform):
    """Standardise the first component of the first batch element with a fixed mean and std.

    The first element of the batch is treated as a pair; only its first component is
    transformed. All other batch elements pass through unchanged.
    """
    order = 99
    parameters = L('mean', 'std')

    def __init__(self,mean,std):
        self.mean = mean
        self.std  = std

    def encodes(self,xy):
        out = []
        for pos,item in enumerate(xy):
            if pos == 0:
                item = ((item[0]-self.mean) / self.std, item[1])
            out.append(item)
        return out

def make_ds(ds,shuffle=False,after_batch= Normalize(signal_mean,signal_std)):
    """Wrap `ds` in a DataLoader with batch size BS, 32 workers and pinned memory.

    `after_batch` defaults to a `Normalize` built from the global signal statistics at
    function-definition time.
    """
    loader_kwargs = dict(shuffle=shuffle, num_workers=32, pin_memory=True, after_batch=after_batch)
    return DataLoader(ds, BS, **loader_kwargs)
            
train_s_dl   = make_ds(train_s_ds,True)
public_s_dl  = make_ds(public_s_ds)
private_s_dl = make_ds(private_s_ds)
train_dl     = make_ds(train_ds)
test_dl      = make_ds(test_ds)

dls = DataLoaders(train_s_dl, public_s_dl, test_dl, device=default_device())

In [None]:
tn = (train[0] - signal_mean) / signal_std
for b in range(tn.shape[0]):
    print(tn[b].min(), tn[b].max())

In [None]:
#next(iter(test_dl))

# Model 

In [None]:
from fastai2.text.models.qrnn import QRNN
class oldClassifier(nn.Module):
    """Earlier recurrent classifier: two stacked bidirectional QRNN blocks plus a per-step MLP head.

    The head sees the recurrent features concatenated with the per-segment histogram and
    outputs 11 logits per step.
    """
    def __init__(self):
        super().__init__()
        f_out = 32

        self.RNN0 = QRNN(input_size=FEAT_WINDOW, hidden_size=f_out, n_layers=1, batch_first=True, 
                            bidirectional=True, dropout=0.)
        self.RNN1  = QRNN(input_size=f_out*2, hidden_size=f_out, n_layers=2, batch_first=True, 
                            bidirectional=True, dropout=0.)
        # Head input = HIST_BINS histogram values + f_out*2 recurrent features.
        self.fc = nn.Sequential(nn.Linear(HIST_BINS+f_out*2, f_out),    Swish(),
                              nn.Linear(f_out,   f_out//2), Swish(),
                              nn.Linear(f_out//2,11))

    def forward(self, x):
        """`x` is a `(signal, histogram)` pair; returns `(logits, recurrent_features)`.

        assumes signal is (batch, seq, FEAT_WINDOW) and histogram is (batch, HIST_BINS) — TODO confirm.
        """
        x,hist = x
        # Repeat the histogram for every time step.
        hist = hist.unsqueeze(1).expand(-1,x.shape[1],-1)
        n_pad = 64
        # ReflectionPad1d pads the last dim, hence the permute round-trip to pad along dim 1.
        x = nn.ReflectionPad1d(n_pad)(x.permute(0,2,1)).permute(0,2,1)
        x, _ = self.RNN0(x)
        x, _ = self.RNN1(x)
        # Remove the padded steps again.
        x = x[:,n_pad:-n_pad,:]
        return self.fc(torch.cat((x,hist),dim=-1)),x

    
class Classifier(nn.Module):
    """Convolutional classifier producing 11 logits per time step.

    Three parts: a stack of six replicate-padded 1-D convolutions over the signal, an MLP
    over the per-segment histogram, and a stack of 1x1 convolutions over the concatenation
    of both feature sets.
    """
    def __init__(self):
        super().__init__()
        width, ksize = 256, 5
        pad = (ksize-1)//2

        def conv_act(c_in):
            # One length-preserving convolution followed by its activation.
            return [nn.Conv1d(c_in,width,ksize,padding=pad,padding_mode='replicate'), Swish()]

        conv_layers = []
        for c_in in (1,width,width,width,width,width):
            conv_layers += conv_act(c_in)
        self.conv = nn.Sequential(*conv_layers)

        hist_hidden = HIST_BINS*4
        self.hist = nn.Sequential(
            nn.Linear(HIST_BINS  , hist_hidden), Swish(),
            nn.Linear(hist_hidden, hist_hidden), Swish(),
            nn.Linear(hist_hidden, width)      , Swish(),
        )

        joint = 2*width
        self.lin  = nn.Sequential(
            nn.Conv1d(joint ,joint ,1), Swish(),
            nn.Conv1d(joint ,joint ,1), Swish(),
            nn.Conv1d(joint ,11    ,1))

    def forward(self, x):
        """`x` is a `(signal, histogram)` pair; returns `(logits, joint_features)`.

        Logits are permuted so that the class dimension is last.
        """
        sig, hist_in = x
        n_batch, n_steps = sig.shape[0], sig.shape[1]
        # B HIST_BINS -> B f -> B f n_steps (histogram features repeated along time)
        hist_feats = self.hist(hist_in).unsqueeze(-1).expand(-1,-1,n_steps)
        # B 1 n_steps -> B f n_steps
        conv_feats = self.conv(sig.view(n_batch,1,-1))
        feats = torch.cat((conv_feats,hist_feats),dim=1)
        return self.lin(feats).permute(0,2,1),feats

model = ReformerLM(
    num_tokens = 11,
    dim = dim,
    depth = depth,
    max_seq_len = WINDOW_SIZE,
    heads = heads,
    lsh_dropout = lsh_dropout,
    bucket_size=bucket_size,
    causal = False,
    use_full_attn = False,
    fixed_position_emb = False,
    n_hashes = 4,
)
model.token_emb = nn.Linear(1,dim)

In [None]:
model = Classifier()
model

# Loss 

In [None]:
def softf1_loss(logits,true,weights=None,label_smoothing=0.):
    """Differentiable (soft) macro-F1 loss: 1 - mean per-class F1 computed from softmax probabilities.

    Parameters
    ----------
    logits : tensor whose last dimension indexes the classes; all leading dimensions are flattened.
    true   : integer class indices with as many elements as `logits` has rows after flattening.
    weights : optional per-class weights. When given, the result is the weighted mean of the
              per-class F1 values (normalised by the sum of the weights of the classes that
              are defined); when None every defined class counts equally.
    label_smoothing : if > 0, one-hot targets are mixed with the uniform distribution.

    Classes whose F1 is undefined (NaN, e.g. a class absent from `true`) are left out of the mean.
    """
    n_classes = logits.shape[-1]
    # reshape (not view): logits may be a non-contiguous permuted tensor.
    y_pred = logits.reshape(-1,n_classes).softmax(dim=-1)
    y_true = F.one_hot(true.flatten(), n_classes).float()
    if label_smoothing > 0: y_true = y_true *(1-label_smoothing) + label_smoothing/n_classes

    # Soft confusion counts per class.
    tp = (y_true * y_pred).sum(dim=0).float()
    fp = ((1 - y_true) * y_pred).sum(dim=0).float()
    fn = (y_true * (1 - y_pred)).sum(dim=0).float()

    precision = tp / (tp + fp )
    recall    = tp / (tp + fn )

    f1 = 2* (precision*recall) / (precision + recall )
    valid = ~torch.isnan(f1)
    if weights is None:
        f1 = f1[valid].mean()
    else:
        w = torch.as_tensor(weights,dtype=f1.dtype,device=f1.device)[valid]
        f1 = (f1[valid] * w).sum() / w.sum()
    return 1-f1

class SoftF1Loss(Module):
    "Module wrapper around `softf1_loss` holding the class weights and label-smoothing amount."
    def __init__(self, label_smoothing=0,weight=None):
        self.weight = weight
        self.label_smoothing = label_smoothing

    def forward(self, output, target):
        return softf1_loss(output, target, self.weight,self.label_smoothing)

class SmartLabelSmoothingCE(Module):
    """Cross-entropy against soft targets that put the smoothing mass on adjacent classes only.

    For an 11-class problem, the target distribution of class `c` assigns `1 - label_smoothing`
    to `c` and splits `label_smoothing` evenly among its direct neighbours `c-1` / `c+1`
    (the two end classes have a single neighbour, which receives all of it).
    """
    def __init__(self, label_smoothing:float=0.0):
        n_classes = 11
        # 1 on the first off-diagonals: entry (i, j) is set when |i - j| == 1.
        off_diag = torch.ones(n_classes-1)
        table = torch.diag(off_diag,1) + torch.diag(off_diag,-1)
        # Each row's neighbour weights sum to `label_smoothing` ...
        table *= (label_smoothing / table.sum(dim=1,keepdim=True))
        # ... and the remaining mass goes on the true class.
        table += (1-label_smoothing) * torch.eye(n_classes)
        self.label_smoothing = table

    def forward(self, logits, true):
        """Mean soft-target cross-entropy; the last dim of `logits` indexes the classes."""
        n_classes = logits.size()[-1]
        # reshape (not view): logits may be a non-contiguous permuted tensor.
        y_pred = logits.reshape(-1,n_classes)
        # Move the small lookup table to the logits' device before indexing, so the index
        # tensor and the table never end up on mismatched devices.
        table = self.label_smoothing.to(logits.device)
        y_true = table[true.flatten()].view(-1,n_classes)
        return (- y_true * F.log_softmax(y_pred, dim=1)).sum(dim=1).mean()

class LabelSmoothingCE(Module):
    """Label-smoothing cross-entropy for per-step predictions.

    `output` has the class dimension last and is permuted to class-second before the
    log-softmax; `target` has a singleton dim 1 that is squeezed away. The result blends the
    uniform-target term (weight `eps`) with the ordinary NLL (weight `1 - eps`).
    """
    def __init__(self, eps:float=0.65, reduction='mean'):
        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target):
        n_classes = output.size()[-1]
        log_probs = F.log_softmax(output.permute(0,2,1), dim=1)   # B C S
        labels = target.squeeze(1)                                 # B S
        if self.reduction == 'sum':
            smooth = -log_probs.sum()
        else:
            smooth = -log_probs.sum(dim=1)
            if self.reduction == 'mean':
                smooth = smooth.mean()
        hard = F.nll_loss(log_probs, labels, reduction=self.reduction)
        return smooth*self.eps/n_classes + (1-self.eps) * hard

class DriftChannelsLoss(Module):
    """Weighted sum of several loss functions applied to the first element of the model output.

    The model output must be a pair; only its first element (the open-channel logits) is
    passed to the losses. `weights` defaults to 1.0 for every loss.
    """
    def __init__(self, losses, weights=None):
        self.losses  = losses
        self.weights = ifnone(weights, [1.] * len(losses))

    def __call__(self, input:Tensor, target:Tensor, **kwargs):
        logits, _features = input
        terms = L()
        for loss_fn,weight in zip(self.losses, self.weights):
            terms.append(loss_fn(logits,target)*weight)
        return terms.sum()
    
sls = SmartLabelSmoothingCE(0.1)
sls(Tensor([[[0,0,0,0,0,0,0,0,0,0,10]]]),LongTensor([[[10]]]))
SoftF1Loss()(Tensor([[[0,0,0,0,0,0,0,0,0,0,10]]]),LongTensor([[[10]]]))

# Metrics

In [None]:
import sklearn.metrics as skm

# Cell
class OpenChannelsAccumMetric(Metric):
    """Stores predictions and targets in accumulate to perform final calculations with `func`.

    Variant of an accumulating metric for models whose output is a tuple: only `learn.pred[0]`
    is compared against `learn.y`.

    Parameters
    ----------
    func        : callable evaluated on the concatenated predictions and targets.
    dim_argmax  : if not None, predictions are reduced with `argmax` over this dimension.
    sigmoid     : apply a sigmoid to the predictions.
    thresh      : if not None, predictions are binarised with `pred >= thresh`.
    to_np       : convert both tensors to NumPy before calling `func`.
    invert_arg  : call `func(targs, preds)` instead of `func(preds, targs)`.
    flatten     : flatten predictions and targets with `flatten_check`.
    metric_name : name reported for the metric; defaults to the name of `func`.
    kwargs      : forwarded to `func`.
    """
    def __init__(self, func, dim_argmax=None, sigmoid=False, thresh=None, to_np=False, invert_arg=False,
                 flatten=True, metric_name=None, **kwargs):
        store_attr(self,'func,dim_argmax,sigmoid,thresh,flatten,metric_name')
        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs

    def reset(self): self.targs,self.preds = [],[]

    def accumulate(self, learn):
        "Append the current batch's predictions and targets."
        t,p = learn.y,learn.pred[0] #learn.y[1],learn.pred[1]
        # Explicit None checks: dim_argmax=0 and thresh=0.0 are valid settings but falsy.
        pred = p.argmax(dim=self.dim_argmax) if self.dim_argmax is not None else p
        if self.sigmoid: pred = torch.sigmoid(pred)
        if self.thresh is not None:  pred = (pred >= self.thresh)
        targ = t
        pred,targ = to_detach(pred),to_detach(targ)
        if self.flatten: pred,targ = flatten_check(pred,targ)
        self.preds.append(pred)
        self.targs.append(targ)

    @property
    def value(self):
        "Result of `func` over everything accumulated so far; None if nothing was accumulated."
        if len(self.preds) == 0: return
        preds,targs = torch.cat(self.preds),torch.cat(self.targs)
        if self.to_np: preds,targs = preds.numpy(),targs.numpy()
        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)

    @property
    def name(self):
        return ifnone(self.metric_name,self.func.func.__name__ if hasattr(self.func, 'func') else  self.func.__name__)

# Cell
def skm_to__open_channels_fastai(func, is_class=True, thresh=None, axis=-1, sigmoid=None, **kwargs):
    """Wrap the sklearn-style metric `func` in an `OpenChannelsAccumMetric`.

    For classification without a threshold the predictions are arg-maxed over `axis`;
    when `sigmoid` is not given it is enabled only for thresholded classification.
    Inputs are converted to NumPy and passed to `func` as `(targets, predictions)`.
    """
    has_thresh = thresh is not None
    if is_class and not has_thresh: argmax_dim = axis
    else:                           argmax_dim = None
    if sigmoid is None: sigmoid = (is_class and has_thresh)
    return OpenChannelsAccumMetric(func, dim_argmax=argmax_dim, sigmoid=sigmoid, thresh=thresh,
                                   to_np=True, invert_arg=True, **kwargs)

def MF1Score(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None, **kwargs):
    "F1 score metric (sklearn `f1_score`) computed on the first element of the model output"
    f1_kwargs = dict(labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)
    return skm_to__open_channels_fastai(skm.f1_score, axis=axis, **f1_kwargs, **kwargs)

def A(inp, targ, axis=-1):
    "Accuracy of the arg-max over `axis` of the first element of `inp` against `targ`"
    logits = inp[0]
    hard_pred,targ = flatten_check(logits.argmax(dim=axis), targ)
    hits = (hard_pred == targ).float()
    return hits.mean()

In [None]:
mets = [MF1Score(labels=[l],average='macro', metric_name=f"f1_{l}") for l in range(11)]
mets.extend([MF1Score(labels=list(range(11)),average='macro', metric_name=f"f1"), A])

# Train

In [None]:
learn = None
gc.collect()
torch.cuda.empty_cache()
learn = Learner(dls,model,loss_func=DriftChannelsLoss([CrossEntropyLossFlat()]),metrics=mets, moms=None,)

#learn.callbacks.extend([F1Metric(learn)])

learn.to_parallel().to_fp16()
summary = learn.summary()
match = re.search(r'Total trainable params: ([0-9,]+)', summary)
model_params = int(match.group(1).replace(",",""))
print(summary)

In [None]:
modelname = 'lstm6843915_t62500_v11250_BS384_SS500000_WS400_FW1_cv0.9424_clean_synth'
try:
    learn.load(modelname, strict=True)
    print(f"Loaded {modelname}")
except Exception as e:
    # Starting from scratch is acceptable here, but report why the load failed
    # (missing file, architecture mismatch, ...) instead of hiding it; a bare `except:`
    # would also have swallowed KeyboardInterrupt.
    print(f"Failed to load {modelname}: {e!r}")

In [None]:
lr_min, lr_steep=learn.lr_find(end_lr=1e-1)

In [None]:
learn.loss_func=DriftChannelsLoss([CrossEntropyLossFlat()],[1.])
learn.fit_one_cycle(1,lr_max=1e-2,moms=(0.95, 0.85, 0.95),pct_start=0.25)

In [None]:
learn.loss_func=DriftChannelsLoss([ SoftF1Loss()])
learn.fit_flat_cos(1,lr=1e-3,pct_start=0.5)

In [None]:
learn.loss_func=DriftChannelsLoss([ SoftF1Loss()])
learn.fit_flat_cos(1,lr=1e-4,pct_start=0.5)

In [None]:
learn.recorder.plot_loss()

In [None]:
learn.loss_func=DriftChannelsLoss([SoftF1Loss(label_smoothing=0.),LabelSmoothingCE()],[40.,1])
learn.fit_flat_cos(1,1e-2,pct_start=0.25)

In [None]:
learn.fit_flat_cos(1,5e-4,pct_start=0.5)

In [None]:
learn.loss_func=DriftChannelsLoss([LabelSmoothingCE(0.75)])

In [None]:
learn.fit_flat_cos(20,1e-3,pct_start=0.1)

In [None]:
learn.summary()

In [None]:
public_s_dl.device=default_device()
p = learn.get_preds(dl=public_s_dl)
y_pred = p[0][0].argmax(dim=-1).flatten()
y_true = p[1].flatten()
public_cv=skm.f1_score(y_true,y_pred,labels=range(11),average='macro')
public_cv

In [None]:
private_s_dl.device=default_device()
p = learn.get_preds(dl=private_s_dl)
y_pred = p[0][0].argmax(dim=-1).flatten()
y_true = p[1].flatten()
private_cv=skm.f1_score(y_true,y_pred,labels=range(11),average='macro')
private_cv

In [None]:
train_dl.device=default_device()
p = learn.get_preds(dl=train_dl)
y_pred = p[0][0].argmax(dim=-1).flatten()
y_true = p[1].flatten()
train_cv=skm.f1_score(y_true,y_pred,labels=range(11),average='macro')
train_cv

In [None]:
cv,_,time = learn.recorder.log[-3:];cv,_,time

In [None]:
suffix =  '_synth_clean50hz'

In [None]:
modelname = f'conv_bn_swift_{model_params}_t{len(train_s_ds)}_v{len(public_s_ds)}_BS{BS}_SS{SEGMENT_SIZE}_WS{WINDOW_SIZE}_pucv{public_cv:0.06f}_prcv{private_cv:0.06f}_trcv{train_cv:0.06f}{DATA_SUFFIX}{suffix}'
learn.save(modelname);modelname

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
train_preds, valid_preds = learn.get_preds(0), learn.get_preds(1)
train_preds = train_preds[0][1],train_preds[1]
valid_preds = valid_preds[0][1],valid_preds[1]

In [None]:
d_feats = train_preds[0].shape[-1]

In [None]:
x0 = np.hstack((learn.model.fc._parameters['weight'].t().cpu().detach().numpy().flatten(),
                learn.model.fc._parameters['bias'].cpu().detach().numpy().flatten()))
x0.shape

In [None]:
x0.shape

In [None]:
#x,y = train_preds[0].view(-1,d_feats).cuda(),train_preds[1].view(-1).cuda()
x = y = None
gc.collect()
torch.cuda.empty_cache()
#x,y = valid_preds[0].view(-1,d_feats).cuda(),valid_preds[1].view(-1).cuda()
x,y = train_preds[0].view(-1,d_feats).cuda(),train_preds[1].view(-1).cuda()


In [None]:
true   = y
evals = 0
max_evals = len(x0) 
print(max_evals)
imb = master_bar(range(max_evals), total=max_evals)
def adjust_thresholds(thresholds):
    """Objective for the optimiser: 1 - macro F1 of a linear read-out on the global features `x`.

    `thresholds` is a flat parameter vector: the first `11*d_feats` entries are the weight
    matrix (reshaped to `(d_feats, 11)`), the remainder is the 11-element bias. Predictions are
    the arg-max of `x @ weight + bias` and are scored against the global `true` labels.
    A class whose F1 is undefined (no true and/or no predicted samples) contributes 0, so the
    objective never becomes NaN.

    Side effects: increments the global evaluation counter `evals` and prints progress every
    1000 evaluations.
    """
    global evals
    m = Tensor(thresholds[:11*d_feats]).view(d_feats,11).cuda()
    b = Tensor(thresholds[11*d_feats:]).view(1,11).cuda()
    preds = (x @ m + b).argmax(dim=-1)

    y_pred = F.one_hot(preds,11).to(torch.float32)  # S, C
    y_true = F.one_hot(true, 11).to(torch.float32)  # S, C

    tp = (y_true * y_pred).sum(dim=0)
    fp = ((1 - y_true) * y_pred).sum(dim=0)
    fn = (y_true * (1 - y_pred)).sum(dim=0)

    precision = tp / (tp + fp )
    recall = tp / (tp + fn )

    f1 = 2* (precision*recall) / (precision + recall)
    # 0/0 for classes that are never predicted (or never present): score them as 0 rather
    # than letting one NaN poison the mean and derail the optimiser.
    f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1), f1)
    f1 = f1.mean().item()

    if evals % 1000 == 0: print(f'{100*evals/max_evals:0.02f}% {f1:0.06f}')#, thresholds)
    evals += 1
    return 1-f1
    
def callback(xk):
    """Per-iteration optimiser callback: print the evaluation count so far.

    Returns True once the evaluation budget `max_evals` has been exceeded, following SciPy's
    convention that a truthy callback result requests termination. (The previous version
    returned True while budget remained, i.e. the opposite.)
    NOTE(review): whether the 'Powell' method honours the return value depends on the SciPy
    version; the `maxfev` option passed to `minimize` enforces the budget either way.
    """
    print(evals)
    return evals > max_evals

res = scipy.optimize.minimize(adjust_thresholds, x0,method='Powell', 
                              options={'disp':True, 'maxfev' : max_evals },
                              callback= callback)

In [None]:
res.x

In [None]:
x0 = res.x

In [None]:
learn.model.fc._parameters['weight']

In [None]:
learn.model.fc._parameters['weight'].data = Tensor(res.x[:11*d_feats]).view(d_feats,11).t().cuda()
learn.model.fc._parameters['bias'].data   = Tensor(res.x[11*d_feats:]).cuda()

In [None]:
learn.model.fc._parameters['weight'],learn.model.fc._parameters['bias']

In [None]:
class F1Dataset(torch.utils.data.Dataset):
    """Per-sample dataset over `(features, targets)`.

    Features are flattened to `(N, d)` where `d` is their last dimension, targets to `(N,)`.
    Each item is `(feature_vector with a trailing singleton dim, target)`.
    """
    def __init__(self, data):
        super().__init__()
        feats, targets = data[0], data[1]
        self.d = feats.shape[-1]
        self.x = feats.view(-1,self.d)
        self.y = targets.view(-1)
        self.n_inp = 1

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        sample = self.x[idx].unsqueeze(1)
        return sample,self.y[idx]
    
f1_train_ds,f1_valid_ds = F1Dataset(train_preds), F1Dataset(valid_preds)
f1_train_dl = DataLoader(f1_train_ds, len(f1_train_ds)//100, shuffle=True,   num_workers=8, pin_memory=True)
f1_valid_dl = DataLoader(f1_valid_ds, len(f1_train_ds)//100, shuffle=False,  num_workers=8, pin_memory=True)

In [None]:
d_feats = train_preds[0].shape[-1]

In [None]:
f1model = nn.Sequential(nn.Conv1d(d_feats   , d_feats//2,3,1,1,groups=1), nn.ReLU(), 
                        nn.Conv1d(d_feats//2, d_feats//4,3,1,1,groups=1), nn.ReLU(), 
                        nn.Conv1d(d_feats//4 ,        11,3,1,1,groups=1),
                        Flatten())

In [None]:
f1learn = None
gc.collect()
torch.cuda.empty_cache()

In [None]:
f1learn = None
gc.collect()
torch.cuda.empty_cache()
f1learn = Learner(DataLoaders(f1_train_dl,f1_valid_dl, device=default_device()),f1model,
                  loss_func=CrossEntropyLossFlat(),opt_func=Adam,
                  metrics=[F1Score(labels=list(range(11)),average='macro'), accuracy])
f1learn.summary()

In [None]:
f1learn.fit(1,1e-3)

In [None]:
f1learn.loss_func = softf1_loss

In [None]:
f1learn.fit_one_cycle(20,5e-3)

In [None]:
learn.model=learn.model.module

In [None]:
test_ds[0]

# Test

In [None]:
test_dl.dataset.data[0,0,:]

In [None]:
# shift 1 sample to the right!!! (we need to reverse it later!)
dd = test.flatten().clone() # torch.empty_like(test_dl.dataset.data.flatten())
dd[1:] = test_dl.dataset.data.flatten()[:-1]
dd[0]  = test_dl.dataset.data.flatten()[-1]
test_dl.dataset.data =  dd.view(test_dl.dataset.data.shape)
test_dl.dataset.data[0,0,:]

In [None]:
plt.plot(test_dl.dataset.data[0,0,:].flatten())


In [None]:
#learn.model = learn.model.module.module

In [None]:
# Display the model's module structure.
learn.model

In [None]:
# Test-time augmentation: run the model over the test loader at 5 evenly spaced
# shift values in [0, WINDOW_SIZE) and accumulate the 11-way outputs per sample.
# `test_preds` ends up holding the SUM over shifts (not the mean); `n_tta`
# counts how many shifts were accumulated.
learn.model.eval()
test_preds  = torch.zeros(*test[0].squeeze().shape,11,dtype=torch.float)
test_preds_ = torch.empty_like(test_preds)   # per-shift scratch buffer
n_tta = 0
learn.model = nn.DataParallel(learn.model)
try:
    with torch.no_grad():
        # `dtype=int`: the `np.int` alias was removed from NumPy (it was the builtin int).
        mb = master_bar(np.linspace(0,WINDOW_SIZE,endpoint=False,num=5,dtype=int))
        for shift in mb:
            # How the dataset uses `.shift` is defined elsewhere in the notebook.
            test_dl.dataset.shift = shift
            test_preds_ = test_preds_.view(-1,11)
            test_preds_[...] = 0.
            s = 0   # write offset into the flat scratch buffer
            for xx in progress_bar(test_dl,parent=mb):
                x = xx[0]
                preds = learn.model((x[0].cuda(),x[1].cuda()))
                # The model returns a pair; only its first element is used here.
                open_channels,_ = preds
                open_channels = open_channels.view(-1,11)
                l = open_channels.shape[0]
                test_preds_[s:s+l] += open_channels.cpu()
                s += l
            test_preds_ = test_preds_.view(-1,TEST_SEGMENT_SIZE,11)
            # Row i of this pass is added at position (i + shift) % TEST_SEGMENT_SIZE.
            ss = torch.arange(0+shift,TEST_SEGMENT_SIZE+shift) % TEST_SEGMENT_SIZE
            for segment in range(test_preds_.shape[0]):
                test_preds[segment,ss] += test_preds_[segment,...]
            n_tta +=1
finally:
    # Always remove the DataParallel wrapper, even if inference raised, so the
    # learner is not left wrapped (and re-wrapped on the next run).
    learn.model = learn.model.module

In [None]:
# Collapse the accumulated predictions to one row of 11 scores per sample and display them.
test_preds = test_preds.view(-1,11)
test_preds

In [None]:
# Predicted class per sample: index of the largest of the 11 accumulated scores.
open_channels = test_preds.argmax(dim=1)
open_channels.shape

In [None]:
# Overlay test[0] (flattened) with the predicted classes; the +10 offsets the
# prediction trace vertically so the two curves do not overlap.
plt.plot(test[0,:].flatten())
plt.plot(open_channels+10)

In [None]:
# Candidate synthetic-segment types; each tuple is unpacked as the leading
# arguments of get_synth_segment_y.
m_types = [('l',1),('l',2),('l',3),('l',4),('l',5),
           ('h',1),('h',2),('h',3),('h',4),('h',5),('h',10)]
# Reference class distribution per type: normalised 11-bin histogram of a
# synthetic segment of SEGMENT_SIZE samples, averaged over 100 draws.
p_dist = np.zeros((len(m_types),11))
for _ in range(100):
    for i,m_type in enumerate(m_types):
        p_dist[i] += np.bincount(get_synth_segment_y(*m_type,SEGMENT_SIZE),minlength=11)/(SEGMENT_SIZE)
p_dist /= 100
# Class distribution of the predictions for every test segment. Iterating the
# rows of the reshaped tensor avoids hard-coding the number of segments and
# re-creating the view for each one.
ts_densities = np.array(
    [np.bincount(seg_labels,minlength=11)/TEST_SEGMENT_SIZE
     for seg_labels in open_channels.view(-1,TEST_SEGMENT_SIZE)])

In [None]:
# Display the shape of the train tensor.
train.shape

In [None]:
# Same candidate types as before: ('l', 1..5) followed by ('h', 1..5 and 10).
m_types = [('l', k) for k in (1, 2, 3, 4, 5)] + [('h', k) for k in (1, 2, 3, 4, 5, 10)]
# Average, over 100 synthetic draws per type, the normalised 11-bin histogram
# of the labels returned by get_synth_segment_y.
p_dist = np.zeros((len(m_types), 11))
for _ in range(100):
    for i, m_type in enumerate(m_types):
        p_dist[i] += (
            np.bincount(get_synth_segment_y(*m_type, SEGMENT_SIZE), minlength=11)
            / (SEGMENT_SIZE)
        )
p_dist /= 100
# Normalised histogram of train[-1, b] for the first 9 entries along axis 1.
# This overwrites any `ts_densities` computed earlier from the test predictions.
ts_densities = np.stack(
    [np.bincount(train[-1, b].squeeze(), minlength=11) / SEGMENT_SIZE for b in range(9)]
)

In [None]:
# Display the per-segment class densities computed above.
ts_densities

In [None]:
from scipy.spatial import distance
# For each row of ts_densities pick the candidate type whose reference
# distribution in p_dist is nearest (cdist's default metric).
t_types = [m_types[j] for j in distance.cdist(ts_densities, p_dist).argmin(axis=1)]
t_types

In [None]:
# Compare the flattened length of test[0,...,0] with the number of predictions.
test[0,...,0].flatten().shape,open_channels.shape

In [None]:
# First third of the 2,000,000 samples: test[0,...,0] with the predicted classes
# drawn 8 units higher so the traces do not overlap.
# The vertical line marks the end of the first test segment. axvline's
# ymin/ymax are axes fractions in [0, 1], so the defaults already span the
# full height of the plot.
plt.axvline(TEST_SEGMENT_SIZE, label='pyplot vertical line')
plt.plot(test[0,...,0].flatten()[:2000000//3])
plt.plot(open_channels[:2000000//3]+8)


In [None]:
# Remaining samples (from index 2000000//3 on): test[0,...,0] with the predicted
# classes drawn 8 units higher.
plt.plot(test[0,...,0].flatten()[2000000//3:])
plt.plot(open_channels[2000000//3:]+8)

In [None]:
# Hand-written type label for each of the 9 train segments, in segment order.
train_types = [
    ('l', 1), ('l', 1),
    ('h', 1), ('h', 3), ('h', 10),
    ('h', 5), ('h', 1), ('h', 5), ('h', 10),
]

In [None]:
# For every labelled train segment, rebuild a noise-free synthetic signal from
# its labels (train[-1, s]) and record the standard deviation of the residual
# against train[0, s], grouped by segment type.
train_types_std = defaultdict(list)
for s,tt in enumerate(train_types):
    xs = train[ 0,s,:].flatten()
    ys = train[-1,s,:].flatten()
    print(ys.shape,xs.shape)
    x_pred = get_synth_segment_xy(None,None,size=SEGMENT_SIZE,y=ys,add_noise=False)[0].flatten()
    plt.plot(xs-x_pred)
    std = (xs-x_pred).std().item()
    train_types_std[tt].append(std)
    #plt.plot(x_pred)

# Display the collected values. This line was indented into the loop body,
# where a bare expression is evaluated and discarded, so nothing was shown.
train_types_std

In [None]:
# Residual check on test segment 1: compare the (shifted) test data with a
# noise-free synthetic signal rebuilt from the predicted classes.
test_types_std = defaultdict(list)
st,sl=5000,10000   # start and length of the window that gets plotted
for s in [1]:
    xs = test_dl.dataset.data[0,s,:]
    ys = open_channels.view(-1,TEST_SEGMENT_SIZE)[s]
    x_pred = get_synth_segment_xy(None,None,size=TEST_SEGMENT_SIZE,y=ys,add_noise=False)[0]
    residual = xs-x_pred
    l = slice(st,st+sl)
    plt.plot(xs[l])
    plt.plot(x_pred[l]-2)   # drawn 2 units lower so it does not hide the data
    plt.plot(residual[l])
    std = (residual).std().item()
    print(std)
    # Store the float already computed above (as the train cell does) instead of
    # recomputing the residual and appending a 0-d tensor.
    # NOTE(review): `tt` is not set in this cell — it is whatever the last loop
    # over train_types left behind, so every entry lands under that train type.
    # Confirm which key was intended for test segments.
    test_types_std[tt].append(std)
    #plt.plot(x_pred)
# Index of the largest absolute residual in the last segment processed.
residual.abs().argmax()

In [None]:
# Display the residual standard deviations collected for the test segments.
test_types_std

In [None]:
# Submission file name built from the model name and the number of TTA passes.
csv_fname = f'{modelname}_n_tta{n_tta}.csv'
csv_fname

In [None]:
# Shifts the flattened test data one sample to the RIGHT (the last sample wraps
# to position 0) and writes it back in the dataset's original shape.
# NOTE(review): this is the same right shift applied before inference, so
# running it moves the data a second sample to the right instead of undoing the
# first shift. Confirm whether a left shift was intended here or whether this
# cell is a leftover copy.
dd = test.flatten().clone() # torch.empty_like(test_dl.dataset.data.flatten())
dd[1:] = test_dl.dataset.data.flatten()[:-1]
dd[0]  = test_dl.dataset.data.flatten()[-1]
test_dl.dataset.data =  dd.view(test_dl.dataset.data.shape)
test_dl.dataset.data[0,0,:]

In [None]:
# Move every prediction one position to the LEFT, with the first prediction
# wrapping round to the end. `open_channels` itself is left untouched.
open_channels_shifted = torch.roll(open_channels, shifts=-1)

In [None]:
# Overlay the left-shifted predictions with the test values for the first
# TEST_SEGMENT_SIZE samples.
plt.plot(open_channels_shifted[:TEST_SEGMENT_SIZE])
plt.plot(test.flatten()[:TEST_SEGMENT_SIZE])

In [None]:
# Build the submission: the `time` column is read as text from the sample
# submission so it is written back exactly as it appears there, and it is
# paired with the left-shifted predictions.
submission_csv_path = p_input / 'sample_submission.csv'
ss = pd.read_csv(submission_csv_path, dtype={'time': str})
test_preds_all = test_preds
test_pred_frame = pd.DataFrame({
    'time': ss['time'].astype(str),
    # Hand pandas a plain NumPy integer array rather than a torch tensor, so
    # the column is a numeric dtype regardless of how pandas treats tensors.
    'open_channels': open_channels_shifted.detach().cpu().numpy(),
})
test_pred_frame.to_csv(csv_fname, index=False)

In [None]:
!kaggle competitions submit -c 'liverpool-ion-switching' -f {csv_fname} -m 'PU {public_cv} PR {private_cv} TR {train_cv}'