In [1]:
from src.model import SPMM, EndToEndModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

In [None]:
class AROMMA(nn.Module):
    """Inference-only wrapper around the end-to-end model and its SPMM embedder.

    On construction the wrapper picks a device (CUDA when available, else CPU),
    builds the ``SPMM`` embedder and the ``EndToEndModel``, loads the saved
    state dict, and switches the model to eval mode.

    Args:
        checkpoint_path: Path of the state-dict file passed to ``torch.load``.
            Defaults to ``"aromma_best_fold.pt"``.
    """

    def __init__(self, checkpoint_path="aromma_best_fold.pt"):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Stored before load_model() runs, because load_model() reads it.
        self.checkpoint_path = checkpoint_path
        self.load_model()

    def load_model(self):
        """Build the embedder and model, load the checkpoint, and enter eval mode."""
        self.embedder = SPMM(r=4, lora_alpha=8, inference=True)
        self.model = EndToEndModel(
            embedder=self.embedder,
            sattn_hidden_dim=196,
            cattn_hidden_dim=384,
            num_heads=4,
            num_labels=152,
        ).to(self.device)
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        state_dict = torch.load(self.checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def forward(self, smiles):
        """Run the wrapped model on ``smiles`` with gradient tracking disabled.

        Args:
            smiles: Input forwarded unchanged to the wrapped model (the notebook
                passes a batch of SMILES strings).

        Returns:
            The wrapped model's output; callers in this notebook treat it as
            logits and apply a sigmoid.
        """
        with torch.no_grad():
            # Call the module rather than .forward() so nn.Module hooks are honoured.
            logit = self.model(smiles)
        return logit
# Instantiate once at notebook scope: construction builds the model and loads the
# checkpoint, and the later inference cells all call this single instance.
embedder = AROMMA()

Class Distribution-Aware Threshold

In [3]:
# Gather every split of every fold, separating rows whose SMILES contains ";"
# (collected in df_bp_arr) from rows whose SMILES does not (collected in df_gslf_arr).
fold_ids = range(1, 6)
split_names = ("train", "valid", "test")

df_gslf_arr, df_bp_arr = [], []
for fold_id in fold_ids:
    for split in split_names:
        frame = pd.read_csv(f"data/mixture/fold{fold_id}/{split}.csv")
        has_separator = frame["smiles"].str.contains(";")
        df_gslf_arr.append(frame[~has_separator])
        df_bp_arr.append(frame[has_separator])


def _stack_unique(frames):
    """Concatenate frames, keep the first row per SMILES, and renumber the index."""
    stacked = pd.concat(frames)
    return stacked.drop_duplicates("smiles").reset_index(drop=True)


df_gslf_total = _stack_unique(df_gslf_arr)
df_bp_total = _stack_unique(df_bp_arr)

In [4]:
# Per-label positive rate over the de-duplicated non-";" rows:
# column sums of every column after the first, divided by the row count.
label_counts = df_gslf_total.iloc[:, 1:].sum()
num_rows = len(df_gslf_total)
gslf_label_ratio = label_counts / num_rows
gslf_label_ratio

acidic        0.008725
alcoholic     0.021396
aldehydic     0.029705
alliaceous    0.026174
almond        0.017657
                ...   
warm          0.019942
waxy          0.100748
weedy         0.009140
winey         0.045700
woody         0.140216
Length: 152, dtype: float64

In [5]:
# Score every de-duplicated ";"-separated SMILES entry with the model, in chunks.
pair_data = df_bp_total["smiles"].values
n = len(pair_data)
if n == 0:
    raise ValueError("df_bp_total contains no SMILES to score")

# Roughly this many SMILES go through the model per forward call.
TARGET_CHUNK_SIZE = 3000

prob_arr = []
# Clamp to at least one chunk: n // TARGET_CHUNK_SIZE is 0 when n < TARGET_CHUNK_SIZE,
# which previously caused a ZeroDivisionError on the next line.
num_chunks = max(1, n // TARGET_CHUNK_SIZE)
chunk_size = n // num_chunks
for j in tqdm(range(num_chunks)):
    start = j * chunk_size
    # The last chunk absorbs the remainder so that every row is scored.
    end = (j + 1) * chunk_size if j < num_chunks - 1 else n
    chunk = pair_data[start:end]

    prob = F.sigmoid(embedder.forward(chunk).detach().cpu())
    prob_arr.append(prob)
probs = torch.cat(prob_arr, dim=0).numpy() # (N, 152)
assert probs.shape[0] == n, "every input SMILES must have exactly one probability row"

100%|██████████| 46/46 [00:58<00:00,  1.28s/it]


In [6]:
# For each label, take the (1 - ratio) quantile of its predicted probabilities
# as the decision threshold, where ratio is that label's entry in gslf_label_ratio.
label_thresholds = [
    (name, np.quantile(probs[:, col], 1 - fraction))
    for col, (name, fraction) in enumerate(gslf_label_ratio.items())
]

# One threshold per line, in the same order as gslf_label_ratio.
with open("threshold.txt", "w") as f:
    for name, cutoff in label_thresholds:
        f.write(f"{cutoff:.4f}\n")
        print(f"{name}: {cutoff:.4f}")

acidic: 0.0466
alcoholic: 0.0052
aldehydic: 0.2651
alliaceous: 0.2543
almond: 0.0134
amber: 0.0693
animal: 0.0424
anisic: 0.0131
apple: 0.0078
apricot: 0.0216
aromatic: 0.0440
balsamic: 0.3256
banana: 0.0076
beefy: 0.0280
bergamot: 0.0146
berry: 0.0599
bitter: 0.0193
black currant: 0.0128
brandy: 0.0347
bready: 0.0072
brown: 0.0298
burnt: 0.0206
buttery: 0.0931
cabbage: 0.0138
camphoreous: 0.0690
caramellic: 0.0912
cedar: 0.0445
celery: 0.0092
chamomile: 0.0088
cheesy: 0.0137
chemical: 0.0148
cherry: 0.0187
chocolate: 0.0645
cinnamon: 0.0119
citrus: 0.0979
clean: 0.0265
clove: 0.0237
cocoa: 0.0202
coconut: 0.0679
coffee: 0.1165
cognac: 0.0132
cooked: 0.0301
cooling: 0.0739
cortex: 0.0149
coumarinic: 0.0169
creamy: 0.1122
cucumber: 0.0036
dairy: 0.0231
dry: 0.0174
earthy: 0.0574
estery: 0.0915
ethereal: 0.0748
fatty: 0.1181
fermented: 0.0502
fishy: 0.0107
floral: 0.6135
fresh: 0.0106
fruit skin: 0.0116
fruity: 0.4247
fungal: 0.0823
fusel: 0.0218
garlic: 0.0129
gassy: 0.0107
geranium: 0.

Pseudo Labelling

In [7]:
def _read_lines(path):
    """Return the lines of a text file with line endings removed."""
    with open(path, "r") as handle:
        return handle.read().splitlines()


total_152 = _read_lines("data/labels_152.txt")
common_74 = _read_lines("data/bp_74.txt")

# Labels listed in labels_152.txt but not in bp_74.txt, with their positions in total_152.
unknown_label = sorted(set(total_152).difference(common_74))
unknown_label_idx = list(map(total_152.index, unknown_label))
assert len(unknown_label) == 78

# One threshold per line of threshold.txt, shaped as a single row for broadcasting.
threshold_arr = list(map(float, _read_lines("threshold.txt")))
thresholds = torch.tensor(threshold_arr).view(1, -1) # (1, number of lines in threshold.txt)

In [None]:
def predict_pseudo_labels(smiles_list, num_chunks=10):
    """Return thresholded 0/1 predictions for ``smiles_list`` as an int32 array.

    The list is scored by ``embedder`` in up to ``num_chunks`` chunks, a sigmoid
    is applied, and each column is compared against its entry in ``thresholds``.

    Args:
        smiles_list: List of SMILES strings to score.
        num_chunks: Upper bound on the number of forward passes.

    Returns:
        numpy array of shape ``(len(smiles_list), thresholds.shape[1])``.
    """
    n = len(smiles_list)
    if n == 0:
        # Nothing to score: return an empty matrix instead of calling the model
        # on empty batches.
        return np.zeros((0, thresholds.shape[1]), dtype=np.int32)

    # Never use more chunks than rows; otherwise chunk_size becomes 0 and
    # empty batches are sent to the model.
    num_chunks = max(1, min(num_chunks, n))
    chunk_size = n // num_chunks

    pred_chunks = []
    for j in range(num_chunks):
        start = j * chunk_size
        # The last chunk absorbs the remainder so that every row is scored.
        end = (j + 1) * chunk_size if j < num_chunks - 1 else n
        chunk = smiles_list[start:end]

        prob = F.sigmoid(embedder.forward(chunk).detach().cpu())
        pred_chunks.append((prob >= thresholds).int())  # (chunk_size, 152)
    return torch.cat(pred_chunks, dim=0).numpy()  # (N, 152)


for i in tqdm(range(1, 6)):
    save_dir1 = f"data/mixture_p78/fold{i}"
    save_dir2 = f"data/mixture_p152/fold{i}"

    os.makedirs(save_dir1, exist_ok=True)
    os.makedirs(save_dir2, exist_ok=True)

    df_train = pd.read_csv(f"data/mixture/fold{i}/train.csv")
    df_valid = pd.read_csv(f"data/mixture/fold{i}/valid.csv")
    df_test = pd.read_csv(f"data/mixture/fold{i}/test.csv")

    for dname, df in [("train", df_train), ("valid", df_valid), ("test", df_test)]:
        # Rows whose SMILES contains ";" get pseudo-labels; the others pass through unchanged.
        mask = df["smiles"].str.contains(";")
        df_mix = df[mask]
        df_mol = df[~mask]

        pred_arr = predict_pseudo_labels(df_mix["smiles"].tolist())  # (N, 152)

        # Variant 1: overwrite only the 78 labels listed in unknown_label.
        df_mix_up = df_mix.copy()
        df_mix_up.loc[:, unknown_label] = pred_arr[:, unknown_label_idx]
        df_made = pd.concat([df_mol, df_mix_up])
        df_made.to_csv(os.path.join(save_dir1, f"{dname}.csv"), index=False)

        # Variant 2: update all label columns (every column after the first).
        df_mix_pred = df_mix.copy()
        df_mix_pred.iloc[:, 1:] = pred_arr
        # Pseudo-labels alone can leave some labels with no positive occurrences,
        # so combine them with the ground truth using an element-wise OR.
        union_mask = ((df_mix.iloc[:, 1:] | df_mix_pred.iloc[:, 1:]) > 0).astype(int)
        df_mix_union = df_mix.copy()
        df_mix_union.iloc[:, 1:] = union_mask.values

        df_made_all = pd.concat([df_mol, df_mix_union])
        df_made_all.to_csv(os.path.join(save_dir2, f"{dname}.csv"), index=False)

100%|██████████| 5/5 [02:07<00:00, 25.60s/it]
