# Import Library MM

In [1]:
import sys
sys.path.append('../data')
import os
import gc
import re
import math
import time
import random
import shutil
import pickle
import itertools
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter
from torch.nn.parallel import DistributedDataParallel as DDP
import scipy as sp
import numpy as np
import pandas as pd
from tqdm import tqdm
# from tqdm.auto import tqdm
import evaluation
from evaluation import PTBTokenizer, Cider

import Levenshtein
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

from functools import partial
from torch.optim.lr_scheduler import LambdaLR
import cv2
from PIL import Image

import torch
import torch.multiprocessing as multiprocessing
#multiprocessing.set_start_method('spawn')
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.nn import NLLLoss
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, Blur
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory, ScaledDotProductAttention
from data import RawField

import warnings 
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')

# Config

In [2]:
num_workers=16
start_lr = 1e-2
seed = 3211
n_fold=5
transform_size = 224
preprocess = 'cnn' # direct, rcnn, cnn
batch_size = 1
beam_size = 3
seq_len = 280

# CNN d_in= 2048, RCNN d_in = 1024, direct d_in = 2352
if preprocess == 'direct':
    m2_d_in = 2352
elif preprocess == 'rcnn':
    m2_d_in = 1024
elif preprocess == 'cnn':
    m2_d_in = 256# 1000 if no fc, else 256


def lambda_lr(s):
    warm_up = 10000
    s += 1
    return (512 ** -.5) * min(s ** -.5, s * warm_up ** -1.5)

def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def bms_collate(batch):
    imgs, labels, label_lengths = [], [], []
    for data_point in batch:
        imgs.append(data_point[0])
        labels.append(data_point[1])
        label_lengths.append(data_point[2])
    labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.stoi["<pad>"])
    return torch.stack(imgs), labels, torch.stack(label_lengths).reshape(-1, 1)

seed_torch(seed=seed)

# Tokenizer

In [3]:

class Tokenizer(object):
    
    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def __len__(self):
        return len(self.stoi)
    
    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        vocab = sorted(vocab)
        vocab.append('<sos>')
        vocab.append('<eos>')
        vocab.append('<pad>')
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        
    def text_to_sequence(self, text):
        sequence = []
        sequence.append(self.stoi['<sos>'])
        for s in text.split(' '):
            sequence.append(self.stoi[s])
        sequence.append(self.stoi['<eos>'])
        return sequence
    
    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequence = self.text_to_sequence(text)
            sequences.append(sequence)
        return sequences

    def sequence_to_text(self, sequence):
        return ''.join(list(map(lambda i: self.itos[int(i)], sequence)))
    
    def sequences_to_texts(self, sequences):
        texts = []
        for sequence in sequences:
            text = self.sequence_to_text(sequence)
            texts.append(text)
        return texts
    
    def predict_caption(self, sequence):
        #print('sequence',sequence)
        #print('sequence shape',sequence.shape)
        caption = ''
        for i in sequence:
            if i == self.stoi['<sos>'] :
                continue
            if i == self.stoi['<eos>'] or i == self.stoi['<pad>']:
                break
            #print('i', i)
            caption += self.itos[int(i)]
        return caption
    
    def predict_captions(self, sequences):
        captions = []
        for sequence in sequences:
            caption = self.predict_caption(sequence)
            captions.append(caption)
        return captions

tokenizer = torch.load('tokenizer.pth')
text_field = tokenizer
print(f"tokenizer.stoi: {tokenizer.stoi}")

tokenizer.stoi: {'(': 0, ')': 1, '+': 2, ',': 3, '-': 4, '/b': 5, '/c': 6, '/h': 7, '/i': 8, '/m': 9, '/s': 10, '/t': 11, '0': 12, '1': 13, '10': 14, '100': 15, '101': 16, '102': 17, '103': 18, '104': 19, '105': 20, '106': 21, '107': 22, '108': 23, '109': 24, '11': 25, '110': 26, '111': 27, '112': 28, '113': 29, '114': 30, '115': 31, '116': 32, '117': 33, '118': 34, '119': 35, '12': 36, '120': 37, '121': 38, '122': 39, '123': 40, '124': 41, '125': 42, '126': 43, '127': 44, '128': 45, '129': 46, '13': 47, '130': 48, '131': 49, '132': 50, '133': 51, '134': 52, '135': 53, '136': 54, '137': 55, '138': 56, '139': 57, '14': 58, '140': 59, '141': 60, '142': 61, '143': 62, '144': 63, '145': 64, '146': 65, '147': 66, '148': 67, '149': 68, '15': 69, '150': 70, '151': 71, '152': 72, '153': 73, '154': 74, '155': 75, '156': 76, '157': 77, '158': 78, '159': 79, '16': 80, '161': 81, '163': 82, '165': 83, '167': 84, '17': 85, '18': 86, '19': 87, '2': 88, '20': 89, '21': 90, '22': 91, '23': 92, '24': 9

# Data Loading

In [4]:
import numpy as np
import pandas as pd
import torch

train = pd.read_pickle('train.pkl')

def get_train_file_path(image_id):
    return "../data/train/{}/{}/{}/{}.png".format(
        image_id[0], image_id[1], image_id[2], image_id 
    )

train['file_path'] = train['image_id'].apply(get_train_file_path)

#print(f'train.shape: {train.shape}')
#display(train.head())

folds = train.copy()
Fold = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=seed)
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['InChI_length'])):
    folds.loc[val_index, 'fold'] = int(n)
folds['fold'] = folds['fold'].astype(int)
#print(folds.groupby(['fold']).size())

# Dataset

In [5]:
class TrainDataset(Dataset):
    def __init__(self, df, tokenizer, transform=None):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.file_paths = df['file_path'].values
        self.labels = df['InChI_text'].values
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = self.labels[idx]
        label = self.tokenizer.text_to_sequence(label)
        label_length = len(label)
        label_length = torch.LongTensor([label_length])
        return image, torch.LongTensor(label), label_length
    

class TestDataset(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.file_paths = df['file_path'].values
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

def get_transforms(*, data):
    
    if data == 'train':
        return Compose([
            Resize(transform_size, transform_size),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])
    
    elif data == 'valid':
        return Compose([
            Resize(transform_size, transform_size),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])


# Data Loader

In [6]:
fold = 0
trn_idx = folds[folds['fold'] != fold].index
val_idx = folds[folds['fold'] == fold].index

train_folds = folds.loc[trn_idx].reset_index(drop=True)
valid_folds = folds.loc[val_idx].reset_index(drop=True)
valid_labels = valid_folds['InChI'].values

train_dataset = TrainDataset(train_folds, tokenizer, transform=get_transforms(data='train'))
valid_dataset = TrainDataset(valid_folds, tokenizer,transform=get_transforms(data='valid'))

train_loader = DataLoader(train_dataset, 
                            batch_size=batch_size, 
                            shuffle=True, 
                            num_workers=num_workers, 
                            pin_memory=True,
                            drop_last=True, 
                            collate_fn=bms_collate)
valid_loader = DataLoader(valid_dataset, 
                            batch_size=batch_size, 
                            shuffle=False, 
                            num_workers=num_workers,
                            pin_memory=True, 
                            drop_last=False, 
                            collate_fn=bms_collate)

#cider_train = Cider(PTBTokenizer.tokenize(ref_caps_train))

# Feature Extraction (No use for now)

In [7]:
import torch
import torchvision
import torch.nn as nn

num_region = 100
RCNN = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, rpn_pre_nms_top_n_train=num_region, rpn_pre_nms_top_n_test=num_region, rpn_post_nms_top_n_train=num_region, rpn_post_nms_top_n_test=num_region)   
RCNN.eval()
RCNN.to(device)


def get_feature(x):

    features = {}

    def get_activation(name):
        def hook(model, input, output):
            features[name] = output.detach()
        return hook

    RCNN.roi_heads.box_head.register_forward_hook(get_activation('feature'))
    out = RCNN(x)
    #print(features['feature'].shape)
    return features['feature'].view(-1,num_region,1024)


In [8]:
import pickle

def save_all_features():

    SAVEPATH = './features/'
    train_data = []
    foo = 0
    with tqdm(desc='Epoch %d - train' % foo, unit='it', total=len(train_loader)) as pbar:
        for i, (detection, caption, length) in enumerate(train_loader):
            detection, caption = detection.to(device), caption.to(device)
            
            detection = get_feature(detection)
            #print('detection.shape', detection.shape)
            #print('caption.shape', caption.shape)
            
            for j in range(detection.shape[0]):
                train_data.append([detection[j,:,:].cpu(), caption[j,:].cpu()])
            pbar.update()
            if i != 0 and i % 100 == 0:
                torch.save(train_data,SAVEPATH + str(i) + '.pt')
                # with open(SAVEPATH + str(i) + '.pickle','wb') as wfp:
                #     pickle.dump(train_data, wfp)
                train_data = []        
#save_all_features()

# Models

In [9]:
CNN = torchvision.models.mobilenet_v2(pretrained=True)
CNN.fc = nn.Identity()
#CNN.eval()
cnn = CNN.to(device)
#print(CNN)

class fcNet(nn.Module):
    def __init__(self):
      super(fcNet, self).__init__()
      self.fc1 = nn.Linear(1000, 4096)
      self.fc2 = nn.Linear(4096, 10240)

    # x represents our data
    def forward(self, x):
      x = self.fc1(x)
      x = F.relu(x)
      x = self.fc2(x)
      x = F.relu(x)
      x = x.view(-1, 40, 256)
      return x
fc = fcNet()
# m = 40
# encoder = MemoryAugmentedEncoder(3, 0, d_in = m2_d_in, attention_module=ScaledDotProductAttentionMemory,
#                                     attention_module_kwargs={'m': m})
# decoder = MeshedDecoder(len(tokenizer), 300, 3, tokenizer.stoi['<pad>'])
# model = Transformer(tokenizer.stoi['<sos>'], encoder, decoder, cnn).to(device)
# CNN d_in= 2048, RCNN d_in = 1024, direct d_in = 2352
encoder = MemoryAugmentedEncoder(3, 0, d_in = m2_d_in, attention_module=ScaledDotProductAttention)
decoder = MeshedDecoder(len(tokenizer), 300, 3, tokenizer.stoi['<pad>'])
model = Transformer(tokenizer.stoi['<sos>'], encoder, decoder, cnn, fc).to(device)

# if torch.cuda.device_count() > 1:
#     model = nn.DataParallel(model)
#     print('use two')
optim = Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100000)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=30, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
#scheduler = LambdaLR(optim, lambda_lr)
loss_fn = NLLLoss(ignore_index=tokenizer.stoi['<pad>'])


# Train and Validation functions

In [10]:
def image_preprocess(detections):
    if preprocess == 'direct':
        detections = torch.reshape(detections, (detections.shape[0], 3, 28, 8, 224))
        detections = torch.reshape(detections, (detections.shape[0], 3, 28, 8, 28, 8))
        detections = torch.moveaxis(detections, 3, -1)
        detections = torch.reshape(detections, (detections.shape[0], -1, 64))
        detections = torch.moveaxis(detections, 1, -1)
        return detections

    elif preprocess == 'cnn':

        return detections

    elif preprocess == 'rcnn':
        with torch.no_grad():
            detections = get_feature(detections)
        return detections

def get_score(y_true, y_pred):
    scores = []
    # y_true = y_true.cpu()
    # y_pred = y_pred.cpu()
    for true, pred in zip(y_true, y_pred):
        print('true', true)
        print('pred', pred)
        score = Levenshtein.distance(true, pred)
        print('t score', score)
        scores.append(score)
    avg_score = np.mean(scores)
    return avg_score

def get_score_mp(caps_gt, caps_gen):
    caps_gt = text_field.predict_caption(caps_gt)
    caps_gen = text_field.predict_caption(caps_gen)
    #print('caps_gt', caps_gt)
    #print('caps_gen', caps_gen)
    score = Levenshtein.distance(caps_gt, caps_gen)
    #print('score', score)
    #return score
    return -score

def train_fn(e, model, dataloader, optim, tokenizer):
    # Training with cross-entropy
    model.train()
    #scheduler.step()
    running_loss = .0
    with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, captions, label_length) in enumerate(dataloader):
            detections, captions = detections.to(device), captions.to(device)
            #detections = image_preprocess(detections)
            
            #print(detections.shape)
            out = model(detections, captions)
            optim.zero_grad()
            captions_gt = captions[:, 1:].contiguous()
            out = out[:, :-1].contiguous()
            loss = loss_fn(out.view(-1, len(tokenizer)), captions_gt.view(-1))
            loss.backward()

            optim.step()
            this_loss = loss.item()
            running_loss += this_loss
            pbar.set_postfix(loss=running_loss / (it + 1), this_loss = this_loss,lr =(optim.param_groups[0]['lr']))
            pbar.update()
            scheduler.step()


    loss = running_loss / len(dataloader)
    return loss


def train_scst(e, model, dataloader, optim, text_field):
    # Training with self-critical
    tokenizer_pool = multiprocessing.Pool()
    running_reward = .0
    running_reward_baseline = .0
    model.train()
    running_loss = .0
    

    with tqdm(desc='Epoch %d - train' % e, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, caps_gt, label_length) in enumerate(dataloader):
            detections = detections.to(device)
            outs, log_probs = model.beam_search(detections, seq_len, text_field.stoi['<eos>'],
                                                beam_size, out_size=beam_size)
            optim.zero_grad()

            
            # Rewards
            caps_gt, outs = caps_gt.cpu(), outs.cpu()
            caps_gt = list(itertools.chain(*([c, ] * beam_size for c in caps_gt)))
            caps_gen = outs.view(-1, seq_len)
            
            #----------------------------------------------------
            # reward = []
            # for i in range(len(caps_gt)):
            #     reward.append(get_score_mp(caps_gt[i],caps_gen[i]))
            #----------------------------------------------------
            reward = tokenizer_pool.starmap(get_score_mp, zip(caps_gt, caps_gen))
            #----------------------------------------------------
            reward = np.array(reward)

            reward = reward.astype(np.float32)
            reward = torch.from_numpy(reward).to(device).view(detections.shape[0], beam_size)
            #reward_baseline = torch.mean(reward, -1, keepdim=True)
            #loss = -torch.mean(log_probs, -1) * (reward - reward_baseline)
            loss = torch.mean(log_probs, -1) * (reward)

            loss = loss.mean()
            loss.backward()
            optim.step()
            scheduler.step()

            running_loss += loss.item()
            running_reward += reward.mean().item()
            #running_reward_baseline += reward_baseline.mean().item()
            pbar.set_postfix(loss=running_loss / (it + 1), reward = torch.mean(reward), reward_m=running_reward / (it + 1), this_loss = loss.item(),lr =(optim.param_groups[0]['lr']))
            pbar.update()
            # bp = 250
            # if it > bp:
            #     break

    loss = running_loss / len(dataloader)
    reward = running_reward / len(dataloader)
    # reward_baseline = running_reward_baseline / len(dataloader)
    # loss = running_loss / bp
    # reward = running_reward / bp
    #reward_baseline = running_reward_baseline / len(bp)
    return loss, reward, 0#reward_baseline

def evaluate_loss(model, dataloader, loss_fn, tokenizer):
    # Validation loss
    model.eval()
    running_loss = .0
    with tqdm(desc='Epoch %d - validation' % 0, unit='it', total=len(dataloader)) as pbar:
        with torch.no_grad():
            for it, (detections, captions, label_length) in enumerate(dataloader):
                detections, captions = detections.to(device), captions.to(device)

                detections = image_preprocess(detections)
                
                out = model(detections, captions)
                captions = captions[:, 1:].contiguous()
                out = out[:, :-1].contiguous()
                loss = loss_fn(out.view(-1, len(tokenizer)), captions.view(-1))
                this_loss = loss.item()
                running_loss += this_loss

                pbar.set_postfix(loss=running_loss / (it + 1))
                pbar.update()

    val_loss = running_loss / len(dataloader)
    return val_loss

def evaluate_metrics(model, dataloader, text_field, device): #TODO:
    import itertools
    model.eval()
    model = model.to(device)
    gen = []
    gts = []
    with tqdm(desc='Epoch %d - evaluation' % 0, unit='it', total=len(dataloader)) as pbar:
        for it, (detections, captions, label_length) in enumerate(iter(dataloader)):
            detections, captions = detections.to(device), captions.to(device)
            detections = image_preprocess(detections)
            
            out = model(detections, captions)

            captions_gt = captions[:, 1:].contiguous()
            out = out[:, :-1].contiguous()
            #out = out.view(-1, len(tokenizer))
            #caps_gt = caps_gt.view(-1)
            out = torch.argmax((out), dim = 2)
            #print(out)
            #caps_gen = text_field.sequences_to_texts(out)
            caps_gen = text_field.predict_captions(out)
            #print('before',captions[0])
            captions = text_field.predict_captions(captions)
            for i, (gts_i, gen_i) in enumerate(zip(captions, caps_gen)):
                #gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen.append(gen_i)
                #print('gen', gen_i)
                gts.append(gts_i)
            #print('gts', gts['0_0'])
            #print('gen', gen['0_0'])
            #print('s1', get_score(gts['0_0'], gen['0_0']))
            #print('s1', get_score('asd', '123'))
            pbar.update()
            if it > 1000:
                break
    scores = get_score(gts, gen)
    #print('s2', get_score(gts, gen))
    return scores

def check_result(model, dataloader, optim, tokenizer):
    # Training with cross-entropy
    model.eval()
    for it, (detections, captions, label_length) in enumerate(dataloader):
        detections, captions = detections.to(device), captions.to(device)
        detections = image_preprocess(detections).to(device)
            
        out = model(detections, captions)
        optim.zero_grad()
        captions_gt = captions[:, 1:].contiguous()
        out = out[:, :-1].contiguous()
        

        #print('out', out.shape)
        #print('gt', captions_gt.shape)
        print('pre id', text_field.predict_caption(torch.argmax((out), dim = 2)[0]))
        print('GT', text_field.predict_caption(captions_gt[0]))
        loss = loss_fn(out.view(-1, len(tokenizer)), captions_gt.view(-1))


        break

def get_csv(model, dataloader, optim, tokenizer):
    # Training with cross-entropy
    tokenizer_pool = multiprocessing.Pool()
    model.eval()
    text_preds = []
    with tqdm(desc='Output csv', total=len(dataloader)) as pbar:
        for it, detections in enumerate(dataloader):
            detections = detections.to(device)
            with torch.no_grad():
                predictions, _ = model.beam_search(
                        detections, seq_len, text_field.stoi["<eos>"], beam_size, out_size=1)
            #print('shape',predictions.shape,predictions)
            #predictions = predictions[:, :-1].contiguous()
            #print('pre', tokenizer.predict_caption(predictions[0].detach().cpu().numpy()))
            #             predictions = predictions.detach().cpu().numpy()
            # predictions = predictions[:, :-1].contiguous()
            predictions = [prediction.detach().cpu().numpy() for prediction in predictions]
            #--------------------------------------------------------------------------------#
            for i in range(len(predictions)):
                text_preds.append(tokenizer.predict_caption(predictions[i]))
            #--------------------------------------------------------------------------------#
            # predictions = tokenizer_pool.starmap(tokenizer.predict_caption, zip(predictions))
            # text_preds.append(predictions)
            #--------------------------------------------------------------------------------#

            # if it > 1:
            #     break
            pbar.update()
    return text_preds


# scores = evaluate_metrics(model, valid_loader, tokenizer)
# print(scores)

# Main Loop

In [11]:
model.load_state_dict(torch.load('m2_cnn2_1_1.pt'))


<All keys matched successfully>

In [12]:
# # def lambda_lr(s):
# #     warm_up = 10000
# #     s += 1
# #     return (512 ** -.5) * min(s ** -.5, s * warm_up ** -1.5)

# def lambda_lr(s):
#     s += 1e-5
#     return (512 ** -.5) * (s ** -.5)

# optim = Adam([{'params': model.parameters(), 'initial_lr': 1e-5}], lr=1e-5, betas=(0.9, 0.98))
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 30000, eta_min=0, last_epoch=-1, verbose=False)
# #scheduler = LambdaLR(optim, lambda_lr)
# loss_fn = NLLLoss(ignore_index=tokenizer.stoi['<pad>'])

In [13]:
def start_train_critical():
    for e in range(0, 1, 1):
        train_loss,_,_ = train_scst(0, model, train_loader, optim, tokenizer)
        val_loss = evaluate_loss(model, valid_loader, loss_fn, tokenizer)
        scores = evaluate_metrics(model, valid_loader, tokenizer, device)
        print('train_loss', train_loss) 
        print('val_loss', val_loss)
        print('scores', scores)
        check_result(model, train_loader, optim, tokenizer)
        torch.save(model.state_dict(), 'm2_cnn_critical_' + str(e) + '.pt')

#start_train_critical()

In [14]:
def start_train():
    for e in range(2, 3, 1):
        train_loss = train_fn(e, model, train_loader, optim, tokenizer)
        val_loss = evaluate_loss(model, valid_loader, loss_fn, tokenizer)
        scores = evaluate_metrics(model, valid_loader, tokenizer, device)
        print('train_loss', train_loss) 
        print('val_loss', val_loss)
        print('scores', scores)
        check_result(model, train_loader, optim, tokenizer)
        torch.save(model.state_dict(), 'm2_cnn2_' + str(e) + '.pt')

#start_train()

In [15]:
#torch.save(model.state_dict(), 'm2_cnn2_1_2.pt')

In [16]:
# val_loss = evaluate_loss(model, valid_loader, loss_fn, tokenizer)
# print('val_loss', val_loss)

In [17]:
# scores = evaluate_metrics(model, valid_loader, tokenizer, device)
# print('scores', scores)

# Submission

In [18]:
# submission

def save2csv():
    test = pd.read_csv('../data/sample_submission.csv')

    def get_test_file_path(image_id):
        return "../data/test/{}/{}/{}/{}.png".format(
            image_id[0], image_id[1], image_id[2], image_id 
        )

    test['file_path'] = test['image_id'].apply(get_test_file_path)

    print(f'test.shape: {test.shape}')
    test_dataset = TestDataset(test, transform=get_transforms(data='valid'))
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=num_workers)
    predictions = get_csv(model, test_loader, optim, tokenizer)
    predictions = np.array([sublist for sublist in predictions])
    #[predictions] = predictions
    # print(type(predictions))
    # print(predictions.shape)
    # print(len(predictions))
    # print(predictions)
    test['InChI'] = [f"InChI=1S/{text}" for text in predictions]
    test[['image_id', 'InChI']].to_csv('submission.csv', index=False)
    test[['image_id', 'InChI']].head()

save2csv()

test.shape: (1616107, 3)
Output csv:   0%|          | 2/6313 [00:58<50:52:49, 29.02s/it]
<class 'numpy.ndarray'>
(768,)
768
['C10H14BrN5S/c1-6-9(11)10(16(3)15-6)4-7(12-2)8-5-17-14-13-8/h5,7,12H,4H2,1-3H3'
 'C15H18ClN3/c16-13-5-3-11(4-6-13)8-17-9-14-7-15(19-14)18-10-12-1-2-12/h3-7,12,17-18H,1-2,8-10H2'
 'C16H13BrN2O/c1-11(20)13-6-7-15(17)16(8-13)19-10-12-4-2-3-5-14(12)9-18/h2-8,19H,10H2,1H3'
 'C14H19FN4O/c1-14(2,3)12-11(13(16)17-18-19(12)4)8-9-5-6-10(15)7-9/h5-7H,8H2,1-4H3,(H2,16,17)'
 'C9H12N2O2/c1-3-4-2-5(6(4)10)9(12)7(3)11-8(9)13-7/h3-7H,2,10H2,1H3/t3-,4-,5-,6-,7-,9-/m1/s1'
 'C42H46O/c1-5-7-9-39(3,4)35-17-13-33(14-18-35)37-31-15-11-29(12-16-31)27-41(6-2,7-8-10-27)40-28-30-19-21-32(22-20-30)23-34(24-32)25-26-36(40)38(33)36/h11-14,29-34H,1-4H3'
 'C16H28N2/c1-5-11-17-16(12-13(3)4)15-10-8-9-14(6-2)18-15/h8-10,13,16-17H,5-7,11-12H2,1-4H3'
 'C26H29Cl2N5O6S/c1-6-38-25(36)21-13(3)30-26(39-7-2)31-22(21)24-32-17-11-10-14(27)12-15(17)23(33-24)28-18-9-8-16(37-4)19(18)29-20(34)12-35/h8-12,23H,6-7

ValueError: Length of values (768) does not match length of index (1616107)

test.shape: (1616107, 3)
Output csv:   0%|          | 6/12626 [01:49<64:04:02, 18.28s/it]
[['C10H14BrN5S/c1-6-8(16(3)15-14-6)4-9(12-2)7-5-17-10(11)13-7/h5,9,12H,4H2,1-3H3', 'C15H18ClN3/c16-13-5-3-11(4-6-13)8-17-9-14-7-15(19-14)18-10-12-1-2-12/h3-7,12,17-18H,1-2,8-10H2', 'C16H13BrN2O/c1-11(20)13-6-7-15(17)16(8-13)19-10-12-4-2-3-5-14(12)9-18/h2-8,19H,10H2,1H3', 'C14H19FN4O/c1-14(2,3)12-11(13(16)17-18-19(12)4)8-9-5-6-10(15)7-9/h5-7H,8H2,1-4H3,(H2,16,17)', 'C9H12N2O2/c1-3-4-2-5-6(7(4)10)8(12)11-9(3)13-5/h3-8H,2,10H2,1H3,(H,11,12)/t3-,4-,5-,6-,7-,8-/m1/s1', 'C42H46O/c1-5-7-9-39(3,4)35-17-19-41(20-18-35,21-10-8-6-2)36-22-24-40(25-23-36)37-26-29-13-11-30(12-14-29)27-31-14-15-32(16-31)28-33(29)34(37)32/h6,8,10,29-34H,5,7,9,11-28H2,1-4H3', 'C16H28N2/c1-5-11-17-16(12-13(3)4)15-10-8-9-14(6-2)18-15/h8-10,13,16-17H,5-7,11-12H2,1-4H3', 'C26H29Cl2N5O6S/c1-6-38-25(36)21-13(3)30-26(39-7-2)31-18-11-10-14(27)12-15(18)22(21)32-24(34)19-16(28)9-8-13(20(19)37-4)23(35)33-40-5/h8-12,22H,6-7H2,1-5H3,(H,33,34)(

ValueError: Length of values (896) does not match length of index (1616107)