## Helper functions

In [1]:
import os

import json
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import multiprocessing as mp
from config import *
def load_relevant_data_subset(pq_path):
    """Read the configured landmark columns (global ``col``) of one parquet file.

    Returns the selected columns as a float32 numpy array, one row per parquet row.
    """
    frame = pd.read_parquet(pq_path, columns=col)
    return frame.values.astype(np.float32)
import time
def read_dict(file_path):
    """Load a JSON file (a leading ``~`` is expanded) and return the decoded object."""
    with open(os.path.expanduser(file_path), "r") as handle:
        return json.load(handle)
class TrainingLogger:
    """Tabular training log written to a text file and optionally echoed to stdout.

    The first :meth:`log` call (while ``first_log`` is True) truncates
    ``log_file_path`` and writes a header row. Every :meth:`log` call then
    appends one row of padded values plus the elapsed time. :meth:`info`
    appends a free-form line.

    Args:
        log_file_path: path of the log file.
        headers: column names; each :meth:`log` call must supply one value per header.
        sep: column separator, used for the header row and the data rows alike.
        print_to_stdout: also print every line that is written.
        p: number of decimals used to round non-string values.
        resume: stored on the instance; not otherwise read by this class
            (resuming callers set ``first_log`` / ``elapsed_time_offset`` directly).
    """
    def __init__(self, log_file_path, headers, sep='    | ',print_to_stdout=True,p=3,resume=False):
        self.log_file_path = log_file_path
        self.headers = headers
        self.sep = sep
        self.print_to_stdout = print_to_stdout
        # Column widths come from the header names; values are left-justified to them.
        self.col_widths = [len(h) for h in headers]
        self.start_time = time.time()
        self.elapsed_time = 0
        # Seconds spent in earlier (resumed) runs, added on top of this run's wall clock.
        self.elapsed_time_offset=0
        self.resume = resume
        self.first_log = True
        self.p = p

    @staticmethod
    def _format_elapsed(seconds):
        """Format a duration as ``HH:MM:SS``.

        Unlike ``time.strftime('%H:%M:%S', time.gmtime(s))``, the hour field
        keeps counting past 24 instead of silently wrapping to 00.
        """
        hours, rem = divmod(int(seconds), 3600)
        minutes, secs = divmod(rem, 60)
        return f'{hours:02d}:{minutes:02d}:{secs:02d}'

    def _write_headers(self):
        """Truncate the log file and write the header row.

        The header uses the same separator as the data rows so the columns line up.
        """
        header_str = self.sep.join(list(self.headers) + ['Time Elapsed'])
        with open(self.log_file_path, 'w') as f:
            f.write(header_str + '\n')
        if self.print_to_stdout:
            print(header_str)

    def _get_padded_strings(self, data):
        """Left-justify each value to its column width; non-strings are first rounded to ``p`` decimals."""
        padded_strings = []
        for width, value in zip(self.col_widths, data):
            text = value if isinstance(value, str) else str(round(value, self.p))
            padded_strings.append(text.ljust(width))
        return padded_strings

    def log(self, *args):
        """Append one row (one value per header) followed by the elapsed time.

        Raises:
            ValueError: if the number of values differs from the number of headers.
        """
        # Validate before touching the file so a bad call does not leave a lone header row.
        if len(args) != len(self.headers):
            raise ValueError(f'Length of arguments should be {len(self.headers)}.')
        if self.first_log:
            self._write_headers()
            self.first_log = False
        elapsed_time = time.time() - self.start_time + self.elapsed_time_offset
        self.elapsed_time = elapsed_time
        data_strings = self._get_padded_strings(args) + [self._format_elapsed(elapsed_time)]
        log_str = self.sep.join(data_strings)
        with open(self.log_file_path, 'a') as f:
            f.write(log_str + '\n')
        if self.print_to_stdout:
            print(log_str)

    def info(self,message):
        """Append a free-form message line (e.g. a "better score" notice) to the log."""
        with open(self.log_file_path, 'a') as f:
            f.write(message + '\n')
        if self.print_to_stdout:
            print(message)

num_point1:  135 num_point2:  408 layerdrop_rate:  0.0 used_folds [2]


## Training utilities (seeding and checkpoint resume)

In [2]:
import random
def seed_everything(seed):
    """Seed python, numpy and torch RNGs and configure cuDNN.

    ``cudnn.deterministic`` is enabled, but ``cudnn.benchmark`` is deliberately
    also enabled (about a 10% speedup according to the original note). As a
    result, GPU runs may still differ slightly from one another.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = True
def resume_fn(fname,model,swa_model,optim,scheduler):
    """Restore a checkpoint written by ``train_one_fold`` and return the training state.

    Restores the torch, CUDA (only when CUDA is available), numpy and python RNG
    states. Also loads the state dicts of ``model`` (strict), ``swa_model``,
    ``optim`` and ``scheduler`` in place.

    Returns:
        (model, swa_model, optim, scheduler, start_epoch, best_score, patience,
        elapsed_time_offset), where ``start_epoch`` is the saved epoch + 1.

    Raises:
        FileNotFoundError: if ``fname`` does not exist. (Previously this was a
            bare ``assert False``, which disappears under ``python -O``.)
    """
    if not os.path.exists(fname):
        print(f'no checkpoint found at {fname}')
        raise FileNotFoundError(fname)
    try:
        # The checkpoint pickles numpy / python RNG states. torch >= 2.6 defaults
        # to weights_only=True, which refuses them. The file is written by this
        # notebook itself, so full unpickling is intended here.
        data = torch.load(fname, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        data = torch.load(fname, map_location="cpu")
    torch.set_rng_state(data['torch_rng_state'])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(data['cuda_rng_state'])
    np.random.set_state(data['numpy_rng_state'])
    random.setstate(data['random_rng_state'])
    model.load_state_dict(data['state_dict'], strict=True)
    swa_model.load_state_dict(data['swa_model'])
    optim.load_state_dict(data['optimizer'])
    scheduler.load_state_dict(data['scheduler'])
    start_epoch = data['epoch'] + 1
    val_score = data['val_score']
    best_score = data['best_score']
    patience = data['patience']
    elapsed_time_offset = data['elapsed_time_offset']
    print(f" resume from {fname} \n at epoch {start_epoch-1} \n with val_score {val_score} best_score {best_score} and elapsed_time_offset {elapsed_time_offset}")
    return model,swa_model, optim, scheduler, start_epoch, best_score, patience,elapsed_time_offset

## Model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# class Model(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#         self.encoder = timm.create_model('tf_efficientnet_b0_ns', pretrained=True,in_chans=2*len(pointset1),num_classes=250,drop_rate=dropout)
#     def forward(self, x):
#         x = torch.cat([x[:,:,:,:,0],x[:,:,:,:,1]],dim=1)
#         x = self.encoder(x)
#         return x

import torch
from torch import nn
from torch.nn import functional as F
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class GELU(nn.Module):
    """Tanh approximation of GELU: ``0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3)))``."""
    def forward(self, x):
        inner = np.sqrt(2/np.pi)*(x+0.044715*torch.pow(x,3))
        return 0.5*x*(1+torch.tanh(inner))
# Activation name (config ``act_name``) -> module class to instantiate.
act={ "quickgelu":QuickGELU,"gelu":GELU,"relu":nn.ReLU}



class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Every head has width ``d_model // h`` for queries, keys and values.
    Attention-weight dropout uses the module-level ``dropout`` rate.
    """
    def __init__(self, d_model,h):
        super().__init__()
        head_dim = d_model // h
        # Construction order (q, k, v, o) is kept so parameter init draws are unchanged.
        self.fc_q = nn.Linear(d_model, h * head_dim)
        self.fc_k = nn.Linear(d_model, h * head_dim)
        self.fc_v = nn.Linear(d_model, h * head_dim)
        self.fc_o = nn.Linear(h * head_dim, d_model)
        self.d_model = d_model
        self.d_k = head_dim
        self.d_v = head_dim
        self.h = h
    def forward(self, queries, keys, values, attention_mask=None):
        """Attend from ``queries`` (b_s, nq, d_model) over ``keys``/``values`` (b_s, nk, d_model).

        ``attention_mask`` is multiplied by the float32 minimum and added to the
        logits, so entries equal to 1 are effectively excluded. It must
        broadcast to (b_s, h, nq, nk).
        """
        batch, n_q = queries.shape[:2]
        n_k = keys.shape[1]
        q = self.fc_q(queries).reshape(batch, n_q, self.h, self.d_k).permute(0, 2, 1, 3)         # (b_s, h, nq, d_k)
        k = self.fc_k(keys).reshape(keys.shape[0], n_k, self.h, self.d_k).permute(0, 2, 3, 1)    # (b_s, h, d_k, nk)
        v = self.fc_v(values).reshape(values.shape[0], n_k, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)
        scores = torch.matmul(q, k) / (self.d_k**0.5)  # (b_s, h, nq, nk)
        if attention_mask is not None:
            scores = scores+attention_mask*torch.finfo(torch.float32).min
        weights = F.dropout(torch.softmax(scores, -1), p=dropout, training=self.training)
        context = torch.matmul(weights, v).permute(0, 2, 1, 3).reshape(batch, n_q, self.h * self.d_v)
        return self.fc_o(context)

class PositionWiseFeedForward(nn.Module):
    """Two-layer ``d_model -> d_model`` MLP using the configured activation and dropout.

    Currently unused by the model in this notebook.
    """
    def __init__(self, d_model=512):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_model)
        self.fc2 = nn.Linear(d_model, d_model)
        # Attribute name kept for compatibility; the activation is act[act_name], not necessarily ReLU.
        self.relu = act[act_name]()
    def forward(self, input):
        hidden = F.dropout(self.relu(self.fc1(input)), p=dropout, training=self.training)
        return F.dropout(self.fc2(hidden), p=dropout, training=self.training)
    
class TransformerBlock(nn.Module):
    """Pre-norm self-attention with a residual connection, followed by a LayerNorm.

    The feed-forward sub-layer is disabled (see the commented-out ``ffn`` line).
    """
    def __init__(self,
        d_model,
        num_head,
    ):
        super().__init__()
        self.attn  = MultiHeadAttention(d_model, num_head)
        # self.ffn   = PositionWiseFeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, x_mask=None):
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, x_mask)
        residual = x + F.dropout(attended, p=dropout, training=self.training)
        return self.norm2(residual)
#########################################################################
 
#########################################################################

class MeanPooling(nn.Module):
    """Masked mean over dim 1 of ``hidden_state``.

    ``attention_mask`` (batch, seq) acts as per-position weights: positions set
    to 1 are averaged and positions set to 0 are ignored. An all-zero row yields
    zeros, because the denominator is clamped to 1e-9.
    """
    def __init__(self) -> None:
        super().__init__()
    def forward(self,hidden_state, attention_mask):
        weights = attention_mask.unsqueeze(-1).expand(hidden_state.size())
        summed = (hidden_state * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

# Not referenced elsewhere in the visible notebook.
num_main_feats = len(LHAND)*2*2*n+num_point1*2
class TransformerEmbedding(nn.Module):
    """Project the concatenated per-frame features to ``embed_dim``.

    The last axis of ``inputs`` is split into three groups of widths
    ``input_dim[0]``, ``input_dim[1]`` and the remainder (expected to be
    ``input_dim[2]``). Each group is split in half, and each half goes through
    its own Linear -> LayerNorm -> activation -> Dropout branch of width
    ``hidden_dim`` (``fc1`` .. ``fc6``). The six branch outputs are concatenated
    in order and fused by ``fc``.
    """
    def __init__(self) -> None:
        super().__init__()

        def branch(in_features):
            return nn.Sequential(nn.Linear(in_features, hidden_dim), nn.LayerNorm(hidden_dim),
                                 act[act_name](), nn.Dropout(dropout))

        # Attribute names and construction order are kept so state_dict keys and
        # init-time RNG draws match previously saved checkpoints.
        self.fc1 = branch(input_dim[0]//2)
        self.fc2 = branch(input_dim[0]//2)
        self.fc3 = branch(input_dim[1]//2)
        self.fc4 = branch(input_dim[1]//2)
        self.fc5 = branch(input_dim[2]//2)
        self.fc6 = branch(input_dim[2]//2)
        self.fc = nn.Sequential(nn.Linear(6*hidden_dim,embed_dim),nn.LayerNorm(embed_dim),act[act_name](),nn.Dropout(dropout))
    def forward(self,inputs):
        d0, d1 = input_dim[0], input_dim[1]
        groups = (
            (inputs[:, :, :d0], d0, self.fc1, self.fc2),
            (inputs[:, :, d0:d0 + d1], d1, self.fc3, self.fc4),
            (inputs[:, :, d0 + d1:], input_dim[2], self.fc5, self.fc6),
        )
        parts = []
        for feats, width, first_half, second_half in groups:
            half = width // 2
            parts.append(first_half(feats[..., :half]))
            parts.append(second_half(feats[..., half:]))
        return self.fc(torch.cat(parts, dim=-1))

class Model(nn.Module):
    """Sign classifier: embedding, learned positional embedding, one TransformerBlock, and masked mean pooling.

    ``layerdrop`` is a non-trainable Parameter holding the probability of
    skipping the encoder block during training. ``layerdrop_decay`` starts as
    None and is expected to be set by the training code.
    """
    def __init__(self, num_class=250):
        super().__init__()
        self.layerdrop=nn.Parameter(torch.tensor(layerdrop), requires_grad=False)
        self.layerdrop_decay=None

        self.x_embed = TransformerEmbedding()
        self.pos_embed = nn.Parameter(torch.zeros((max_length, embed_dim)), requires_grad=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.encoder=TransformerBlock(embed_dim, num_head)
        self.logit = nn.Sequential( nn.Linear(embed_dim, num_class))
        self._init_weights()
    def _init_weights(self):
        """Linear: N(0, 0.02) weights and zero bias; LayerNorm: unit weight and zero bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
    def forward(self, x,x_mask):
        """Return class logits for padded inputs.

        ``x`` is (B, L, features); ``x_mask`` is (B, L) with 1 at padded
        positions, as produced by ``pack_seq``.
        """
        _, seq_len = x.shape[:2]
        h = self.norm(self.x_embed(x) + self.pos_embed[:seq_len].unsqueeze(0))
        attn_mask = x_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, L)
        # Layer drop: during training, skip the encoder with probability ``layerdrop``.
        skip_encoder = self.training and random.random() < self.layerdrop
        if not skip_encoder:
            h = self.encoder(h, attn_mask)
        # Invert the padding mask so real frames get weight 1 in the mean.
        pooled = MeanPooling()(h, 1 - attn_mask[:, 0, 0, :])
        pooled = F.dropout(pooled, p=dropout, training=self.training)
        return self.logit(pooled)

In [4]:
# Build the model once, save its freshly initialised weights, and display the same instance.
# (Previously two separate Models were built, so the saved and displayed ones differed
# and the global RNG was advanced twice.)
demo_model = Model()
torch.save(demo_model.state_dict(), 'model.pth')
demo_model

Model(
  (x_embed): TransformerEmbedding(
    (fc1): Sequential(
      (0): Linear(in_features=135, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): QuickGELU()
      (3): Dropout(p=0.3, inplace=False)
    )
    (fc2): Sequential(
      (0): Linear(in_features=135, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): QuickGELU()
      (3): Dropout(p=0.3, inplace=False)
    )
    (fc3): Sequential(
      (0): Linear(in_features=672, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): QuickGELU()
      (3): Dropout(p=0.3, inplace=False)
    )
    (fc4): Sequential(
      (0): Linear(in_features=672, out_features=128, bias=True)
      (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (2): QuickGELU()
      (3): Dropout(p=0.3, inplace=False)
    )
    (fc5): Sequential(
      (0): Linear(in_features=882, out_features

In [5]:
# njkb

NameError: name 'njkb' is not defined

## DataSet

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# def pre_process(xyz):
#     xyz = xyz - xyz[~torch.isnan(xyz)].mean(0,keepdims=True) #noramlisation to common mean
#     xyz = xyz / xyz[~torch.isnan(xyz)].std(0, keepdims=True)
#     xyz[torch.isnan(xyz)] = 0
#     xyz = center_crop(xyz, len(xyz))
#     return xyz
class aug():
    """Training-time augmentation for a (frames, points, coords) landmark tensor.

    Steps, applied in order:
      1. ``flip`` — with probability ``p``, mirror the x coordinate (channel 0)
         within its observed range and swap the ``LHAND`` / ``RHAND`` points.
      2. ``random_drop`` — whenever ``p > 0`` (not with probability ``p``), each
         frame is dropped independently with probability ``fd``. At least one
         frame is always kept.

    Args:
        p: flip probability; any value > 0 also enables frame dropping.
        fd: per-frame drop probability in [0, 1). The default 0.4 matches the
            value that was previously hard-coded.

    Raises:
        ValueError: if ``fd`` is outside [0, 1). With ``fd >= 1`` no frame could
            ever survive and the resampling loop would never terminate.
    """
    def __init__(self,p=0.,fd=0.4):
        if not 0.0 <= fd < 1.0:
            raise ValueError(f"fd must be in [0, 1), got {fd}")
        self.p = p
        self.fd = fd
    def __call__(self,sample):
        sample = self.flip(sample)
        sample = self.random_drop(sample)
        return sample
    def random_drop(self,sample):
        # An empty sequence has nothing to drop; resampling on it would loop forever.
        if self.p>0 and sample.shape[0] > 0:
            indices = torch.empty(0, dtype=torch.long)
            # Redraw until at least one frame survives.
            while len(indices) == 0:
                indices = (torch.rand(sample.shape[0]) >=self.fd).nonzero().squeeze(1)
            sample = sample[indices]
        return sample
    def flip(self,sample):
        # Mutates ``sample`` in place and also returns it.
        if torch.rand(1) < self.p:
            x = sample[:,:,0]
            observed = x[~torch.isnan(x)]
            # Skip the mirror when every x is NaN (``max()`` of an empty tensor raises).
            if observed.numel() > 0:
                sample[:,:,0] = observed.max() - x + observed.min()
            # Advanced indexing returns a copy, so ``tmp`` survives the overwrite below.
            tmp = sample[:,LHAND]
            sample[:,LHAND] = sample[:,RHAND]
            sample[:,RHAND] = tmp
        return sample
def norm(x):
    """Standardise ``x`` using the mean and (unbiased) std of its non-NaN entries.

    Entries that are NaN afterwards — the original NaNs, plus any produced by
    a zero or undefined std — are replaced with 0.
    """
    observed = x[~torch.isnan(x)]
    centred = (x - observed.mean(0, keepdim=True)) / observed.std(0, keepdim=True)
    centred[torch.isnan(centred)] = 0
    return centred
def norm_xy(x):
    """Apply ``norm`` separately to channels 0 and 1 of the last axis (other channels are dropped)."""
    return torch.cat([norm(x[..., c:c + 1]) for c in (0, 1)], dim=-1)
def pre_process(x):
    """Normalise each channel of a 3-D ``x`` with ``norm`` (in place), then ``center_crop`` along time.

    Note: ``x`` is mutated. ``SignDataset`` does not use this helper.
    """
    for channel in range(x.shape[-1]):
        x[:, :, channel] = norm(x[:, :, channel])
    return center_crop(x, len(x))
def center_crop(xyz, valid_len, max_len=None):
    """Keep the central ``max_len`` entries of ``xyz`` along axis 0.

    Args:
        xyz: tensor or sequence indexed by frame on axis 0 (any trailing shape).
        valid_len: number of frames to consider; callers pass ``len(xyz)``.
        max_len: crop length. Defaults to the config ``max_length``, which was
            previously the only option.

    Returns:
        ``xyz`` unchanged when ``valid_len <= max_len``; otherwise
        ``xyz[start:start + max_len]`` with ``start = (valid_len - max_len) // 2``.
    """
    if max_len is None:
        max_len = max_length
    if valid_len > max_len:
        start = (valid_len - max_len) // 2
        xyz = xyz[start:start + max_len]
    return xyz

class SignDataset(Dataset):
    """Per-sequence landmark features for the sign classifier.

    Each row of ``df`` provides ``path`` (a parquet file relative to
    ``data_dir``) and ``label``. ``__getitem__`` returns
    ``{"x": features, "label": label}``. ``features`` is a float tensor of
    shape (frames, feature_dim), with at most the config ``max_length`` frames
    after ``center_crop``.

    The feature axis concatenates, in this order:
      [x of pointset1 | y of pointset1 | hand temporal-offset x | hand temporal-offset y |
       left-hand pairwise x | right-hand pairwise x | left-hand pairwise y | right-hand pairwise y]
    The temporal offsets compare each hand point with the frames 1..n//2 before
    and after it (zero where that frame does not exist).

    Args:
        df: metadata frame with ``path`` and ``label`` columns.
        augment: optional callable applied to the raw
            (frames, ROWS_PER_FRAME, len(col)) tensor before feature extraction.
        data_dir: directory the ``path`` column is relative to. It was
            previously hard-coded; the default keeps the old location.
    """
    def __init__(self, df, augment=None, data_dir='/mnt/hdd1/wangjingqi/GD'):
        self.df = df
        self.augment = augment
        self.data_dir = data_dir
        self.length = len(self.df)
    def __len__(self):
        return self.length
    def __getitem__(self, index):
        d = self.df.iloc[index]

        pq_file = f'{self.data_dir}/{d.path}'
        x = load_relevant_data_subset(pq_file)

        # Parquet rows are grouped per frame: ROWS_PER_FRAME landmark rows each.
        n_frames = len(x) // ROWS_PER_FRAME
        x = x.reshape(n_frames, ROWS_PER_FRAME, len(col))
        x = torch.from_numpy(x).float()
        if self.augment is not None:
            x = self.augment(x)
        # Absolute coordinates: normalised over all landmarks, then restricted to pointset1.
        xy = norm_xy(x)[:,pointset1]
        xy = center_crop(xy, len(xy))
        x = x[:,pointset1]
        x = center_crop(x, len(x))

        # NOTE(review): assumes pointset1 is ordered LIP, LHAND, RHAND — confirm in config.
        lhand = x[:,len(LIP):len(LIP)+len(LHAND)]
        rhand = x[:,len(LIP)+len(LHAND):]
        # Pairwise within-hand offsets, shape (T, P, P, coords).
        relative_lhand = norm_xy(lhand.unsqueeze(1) - lhand.unsqueeze(2))
        relative_rhand = norm_xy(rhand.unsqueeze(1) - rhand.unsqueeze(2))
        hands = x[:,len(LIP):].permute(1,0,2)  # (hand_points, T, coords)
        relative_back = torch.zeros(hands.shape[0],hands.shape[1],n//2,2)
        relative_front = torch.zeros(hands.shape[0],hands.shape[1],n//2,2)
        for i in range(n//2):
            # Displacement of every hand point across a gap of i+1 frames.
            off = hands[:,i+1:]-hands[:,:-i-1]
            relative_back[:,i+1:,i] = off
            relative_front[:,:-i-1,i] = -off
        relative = torch.cat([relative_back,relative_front],-2)  # (hand_points, T, 2*(n//2), 2)
        relative_xy = norm_xy(relative).permute(1,0,2,3)  # (T, hand_points, 2*(n//2), 2)
        abs_x = xy[...,0:1].flatten(-2)  # (T, num_point1)
        abs_y = xy[...,1:2].flatten(-2)
        relative_x = relative_xy[...,0:1].flatten(-3)  # (T, hand_points*2*(n//2))
        relative_y = relative_xy[...,1:2].flatten(-3)
        relative_lhand_x = relative_lhand[...,0:1].flatten(-3)  # (T, P*P)
        relative_lhand_y = relative_lhand[...,1:2].flatten(-3)
        relative_rhand_x = relative_rhand[...,0:1].flatten(-3)
        relative_rhand_y = relative_rhand[...,1:2].flatten(-3)
        features = torch.cat([abs_x,abs_y,relative_x,relative_y,relative_lhand_x,relative_rhand_x,relative_lhand_y,relative_rhand_y],-1)
        sample = {"x":features,"label":d.label}
        return sample
def pack_seq(
    seq,
):
    """Right-pad a list of (len_i, features) tensors into one batch.

    Returns:
        x: float32 tensor (batch, max_len, features), zero-padded.
        x_mask: float32 tensor (batch, max_len) with 1 at padded positions and 0 elsewhere.
    """
    lengths = [item.shape[0] for item in seq]
    longest = max(lengths)
    x = torch.zeros(len(seq), longest, seq[0].shape[-1])
    x_mask = torch.zeros((len(seq), longest))
    for row, (item, length) in enumerate(zip(seq, lengths)):
        x[row, :length] = item
        x_mask[row, length:] = 1
    return x, x_mask

def null_collate(batch):
    """Collate sample dicts into a dict of lists, converting ``label`` to a LongTensor.

    Other fields (e.g. variable-length ``x`` tensors) stay as Python lists, so
    they can be padded later with ``pack_seq``.
    """
    collated = {key: [sample[key] for sample in batch] for key in batch[0].keys()}
    collated['label'] = torch.LongTensor(collated['label'])
    return collated


## Loss, metrics, and train/validation loops

In [None]:
from tqdm import tqdm
from sklearn.metrics import f1_score
DEBUG = 0

import torch
import torch.nn as nn

def loss_fn(preds, labels):
    """Cross-entropy between logits ``preds`` and integer ``labels``, using the config label smoothing ``ls``."""
    criterion = nn.CrossEntropyLoss(label_smoothing=ls)
    return criterion(preds, labels)
def compute_score(preds, labels):
    """Cumulative top-k hit rates for k = 1..5.

    Element ``i`` of the returned numpy array is the fraction of samples whose
    label is among the ``i + 1`` highest-scoring classes. The array is shorter
    than 5 when there are fewer classes.
    """
    scores = preds.detach().cpu().numpy()
    targets = labels.detach().cpu().numpy()
    ranked = np.argsort(-scores, -1)
    hits = ranked == targets.reshape(targets.shape[0], 1)
    return hits.cumsum(-1).mean(0)[:5]

def train_fn(model, train_loader, optim, scaler,device):
    """Train ``model`` for one epoch with automatic mixed precision.

    For each batch, the variable-length ``x`` list is padded with ``pack_seq``,
    the tensors are moved to ``device[0]``, and one optimizer step is taken
    through the AMP ``scaler``.

    Returns:
        (mean training loss, mean per-batch ``compute_score`` array), both
        averaged over the batches actually processed. With ``DEBUG`` set only
        one batch runs, so dividing by ``len(train_loader)`` would under-report.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = 0
    train_score = 0
    n_batches = 0
    with tqdm(desc='train', total=len(train_loader)) as pbar:
        for it,batch in enumerate(train_loader):
            batch["x"],batch["mask"]=pack_seq(batch["x"])
            batch['x'] = batch['x'].to(device[0])
            batch['mask'] = batch['mask'].to(device[0])
            batch['label'] = batch['label'].to(device[0])
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled = True):
                preds = model(batch['x'],batch['mask'])
                # If the model returns twice as many rows as there are labels
                # (presumably a doubled batch), repeat the labels to match.
                if preds.shape[0]!=batch["label"].shape[0]:
                    batch["label"]=torch.cat([batch["label"],batch["label"]],0)
                loss = loss_fn(preds, batch['label'])
                score = compute_score(preds, batch['label'])
                train_loss += loss.item()
                train_score += score
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            n_batches += 1
            pbar.set_postfix_str(f'loss={train_loss/ (it + 1):.4f}')
            pbar.update()
            if DEBUG:
                break
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    train_loss /= n_batches
    train_score /= n_batches
    return train_loss, train_score
def validate_fn(model, valid_loader,device):
    """Evaluate ``model`` over the whole validation loader without gradients.

    Logits from all batches are concatenated on ``device[0]``. ``loss_fn`` and
    ``compute_score`` are then applied once, over the full set.

    Returns:
        (validation loss as a Python float, ``compute_score`` array for the full set).
    """
    model.eval()
    logits_per_batch = []
    labels_per_batch = []
    with torch.no_grad():
        with tqdm(desc='valid', total=len(valid_loader)) as pbar:
            for batch in valid_loader:
                batch["x"], batch["mask"] = pack_seq(batch["x"])
                for key in ('x', 'mask', 'label'):
                    batch[key] = batch[key].to(device[0])
                logits_per_batch.append(model(batch['x'], batch['mask']))
                labels_per_batch.append(batch['label'])
                pbar.update()
                if DEBUG:
                    break
            all_preds = torch.cat(logits_per_batch, 0)
            all_labels = torch.cat(labels_per_batch, 0)
            val_loss = loss_fn(all_preds, all_labels)
            val_score = compute_score(all_preds, all_labels)
    return val_loss.item(), val_score

## Train Loop

In [None]:
import shutil

import gc
from torch.optim.swa_utils import AveragedModel, SWALR
def train_one_fold(fold,exp_name,model_path,train_loader,valid_loader,model,optim,epochs,resume_last,resume_best,device,headers):
    """Train and validate one cross-validation fold, checkpointing every epoch.

    ``model`` is wrapped in ``nn.DataParallel`` over ``device`` (a list of
    device ids; ``device[0]`` is the output device). After epoch ``swa_start``,
    weights are averaged into an ``AveragedModel``: its LR follows ``SWALR``,
    and it is the model that gets validated.

    Every epoch the full training state is written to
    ``<model_path>/last/<exp>_last.pth``. When validation top-1 ties or beats
    the best so far, that file is copied to ``<model_path>/best/<exp>_best.pth``.

    Returns:
        The best validation top-1 hit rate (``val_score[0]``).
    """
    best_score = .0
    patience = 0
    start_epoch = 0
    # NOTE: best_epoch is not stored in checkpoints; after a resume it stays 0
    # until a new best appears.
    best_epoch = 0
    exp_name = f"{exp_name}_fold{fold}"
    log = TrainingLogger(f'{model_path}/log{exp_name}.txt', headers= headers)
    scaler = torch.cuda.amp.GradScaler(enabled = True)
    swa_model = AveragedModel(model)
    swa_start = 40
    scheduler = SWALR(optim, swa_lr=1e-4,anneal_epochs=1,anneal_strategy="linear")
    last_path = os.path.join(model_path, "last", f"{exp_name}_last.pth")
    best_path = os.path.join(model_path, "best", f"{exp_name}_best.pth")
    if resume_last or resume_best:
        fname = last_path if resume_last else best_path
        if os.path.exists(fname):
            model,swa_model, optim, scheduler, start_epoch, best_score, patience,elapsed_time_offset = resume_fn(fname,model,swa_model,optim,scheduler)
            log.elapsed_time_offset =elapsed_time_offset
            # The existing log file already has its header row; append to it.
            log.first_log = False
        else:
            print("Warning: resume file not found")
    print(f"Training {exp_name} from epoch {start_epoch}")

    model = nn.DataParallel(model, device_ids=device, output_device=device[0])

    for e in range(start_epoch, epochs):
        torch.cuda.empty_cache()
        gc.collect()
        # Training
        train_loss, train_score = train_fn(model, train_loader, optim, scaler,device)
        # Linear layer-drop schedule. This is skipped when no decay was configured:
        # Model initialises layerdrop_decay to None, and ``-= None`` would raise.
        # Clamping stops float drift from pushing the probability below 0.
        decay = model.module.layerdrop_decay
        if decay:
            with torch.no_grad():
                model.module.layerdrop.sub_(decay).clamp_(min=0.0)
        print(f"layerdrop:{model.module.layerdrop}")
        # Validation (the SWA average takes over after swa_start)
        if e > swa_start:
            swa_model.update_parameters(model)
            scheduler.step()
            print("swa update",scheduler.get_last_lr())
            val_loss, val_score = validate_fn(swa_model, valid_loader,device)
        else:
            val_loss, val_score = validate_fn(model, valid_loader,device)
        train_top0 = train_score[0]
        val_top0, val_top1, val_top2, val_top3, val_top4 = val_score[0], val_score[1], val_score[2], val_score[3], val_score[4]
        log.log(e, train_loss, train_top0, val_loss, val_top0, val_top1, val_top2, val_top3, val_top4)

        best=False
        if val_top0 >= best_score:
            best_score = val_top0
            best_epoch = e
            log.info(f"Better score {best_score:.4f} is found at epoch {e}")
            patience = 0
            best = True
        else:
            patience += 1
        torch.save({
            'torch_rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state(),
            'numpy_rng_state': np.random.get_state(),
            'random_rng_state': random.getstate(),
            'epoch': e,
            "fold": fold,
            "exp_name":exp_name,
            'val_loss': val_loss,
            'val_score': val_score,
            "elapsed_time_offset": log.elapsed_time,
            'state_dict': model.module.state_dict(),
            "swa_state_dict":swa_model.module.state_dict(),
            "swa_model":swa_model.state_dict(),
            'optimizer': optim.state_dict(),
            'scheduler': scheduler.state_dict(),
            'patience': patience,
            'best_score': best_score,
        }, last_path)

        if best:
            shutil.copyfile(last_path, best_path)
        # Early stopping is currently disabled:
        # if patience >= patiences:
        #     log.info(f"Early stopping at epoch {e} with best score {best_score:.4f} at epoch {best_epoch}")
        #     break
    log.info(f"train done for fold {fold} with best score {best_score:.4f} at epoch {best_epoch} \n")
    return best_score

## Config

In [None]:
# Training hyper-parameters. Model and data settings (n, seed, act_name, fd, p, ls,
# num_point1, num_block, embed_dim, hidden_dim, layerdrop, device, ...) come from `config`.
lr = 1e-4
epochs = 200
bs = 64
patiences = 10  # early-stopping patience (early stopping is currently disabled)
n_fold = 5
device = device  # re-exported from config; used as a list (device[0], DataParallel device_ids)
resume_last = True
resume_best = False
exp_name = f"n{n}_seed{seed}_{act_name}_fd{int(fd*10)}_p{int(p*10)}_ls{int(ls*10)}_pn{num_point1}_{num_block}_{embed_dim}_{hidden_dim}_{bs}_ld{int(layerdrop*10)}"
# exp_name = "test"
model_path = f'/mnt/hdd1/wangjingqi/GD/asl/{exp_name}'

HEADERS = ['epoch', 'train_loss',"train_top0", 'val_loss', 'val_top0', 'val_top1', 'val_top2', 'val_top3', 'val_top4']

# Checkpoint directories: per-epoch "last" and best-so-far "best".
for subdir in ("last", "best"):
    os.makedirs(os.path.join(model_path, subdir), exist_ok=True)


In [None]:
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv("/mnt/hdd1/wangjingqi/GD/train.csv")
label_index = read_dict("/mnt/hdd1/wangjingqi/GD/sign_to_prediction_index_map.json")
# Inverse mapping: prediction index -> sign name.
index_label = {idx: sign for sign, idx in label_index.items()}
df["label"] = df["sign"].map(label_index.__getitem__)
# NOTE(review): `groups` is not used below; the split groups by participant_id instead.
groups = df["path"].map(lambda x: x.split("/")[1])
sgkf = StratifiedGroupKFold(n_splits=n_fold, random_state=seed, shuffle=True)
df["fold"] = -1
# Assign each row the index of the fold in which it is a validation sample.
splits = sgkf.split(df["path"], df["label"], df["participant_id"])
for fold_id, (_, valid_index) in enumerate(splits):
    df.loc[valid_index, "fold"] = fold_id
df.to_csv("/mnt/hdd1/wangjingqi/GD/train_fold.csv", index=False)
# df = pd.read_csv("/mnt/hdd1/wangjingqi/GDtrain_prepared.csv")

## Train

In [None]:

all_best_score = {}
def get_optimizer_params(model, lr):
    """Split trainable parameters into weight-decay (0.01) and no-decay (0.0) groups.

    Biases and all ``nn.LayerNorm`` parameters are excluded from weight decay.
    LayerNorm parameters are now identified by module type. The previous name
    check for ``"LayerNorm.weight"`` never matched this model's parameter names
    (e.g. ``norm.weight``, ``x_embed.fc1.1.weight``), so LayerNorm gains were
    being decayed. Parameters with ``requires_grad=False`` (such as
    ``Model.layerdrop``) are left out entirely.

    NOTE: group sizes differ from the old grouping. Optimizer states saved
    before this change will therefore fail ``optim.load_state_dict``.

    Returns:
        A list of two AdamW param-group dicts (decay group first), both with ``lr``.
    """
    no_decay_ids = set()
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if name == "bias" or isinstance(module, nn.LayerNorm):
                no_decay_ids.add(id(param))
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (no_decay if id(param) in no_decay_ids else decay).append(param)
    return [
        {"params": decay, "weight_decay": 0.01, "lr": lr},
        {"params": no_decay, "weight_decay": 0.0, "lr": lr},
    ]
for fold in used_folds:
    seed_everything(seed)
    model =  Model().to(device[0])

    print(model)
    if optim_type == "adamw":
        optim = torch.optim.AdamW(params=get_optimizer_params(model,lr))
    else:
        # Without this, `optim` would be undefined on the first fold or — worse — silently
        # reuse the previous fold's optimizer, which is bound to the previous model's parameters.
        raise ValueError(f"unsupported optim_type: {optim_type!r}")
    train_df = df[df.fold != fold].reset_index(drop=True)
    valid_df = df[df.fold == fold].reset_index(drop=True)
    train_dataset = SignDataset(train_df,aug(p))
    valid_dataset = SignDataset(valid_df,None)
    # Sanity check: configured embedding input width vs. the feature width of one real sample.
    print(np.sum(input_dim),valid_dataset[0]["x"].shape[-1])
    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, num_workers=nw, pin_memory=True, drop_last=True,collate_fn=null_collate)
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, num_workers=nw, pin_memory=True, drop_last=False,collate_fn=null_collate)
    # Linear layer-drop schedule: reaches 0 after `epochs` per-epoch decrements.
    model.layerdrop_decay=layerdrop/epochs
    best_score_fold = train_one_fold(fold,exp_name,model_path,train_loader,valid_loader,model,optim,epochs,resume_last,resume_best,device,HEADERS)
    all_best_score[f"fold{fold}"] = best_score_fold
    # Only the first trained fold may resume from a checkpoint.
    resume_best = False
    resume_last = False

In [None]:

# Report the best validation top-1 per fold and append a run summary to the shared experiment log.
print("best score for all fold is ",all_best_score)
with open("/mnt/hdd1/wangjingqi/GD/asl/log.txt","a") as f:
    f.write(f"{exp_name} with epoch {epochs} lr {lr} bs {bs}  seed {seed} max_length {max_length} embed_dim {embed_dim} num blocks {num_block} num_heads {num_head}  \n for all fold is {all_best_score} \n")
