In [1]:
cd ../../

/pless_nfs/home/zeyu/github/gwuvision/reverse-pheno


In [2]:
import torch, os
import pandas as pd
from genetic_marker_dataset import EmbeddingGeneDataset

In [3]:
# Restrict this kernel to GPU 0, with devices enumerated in PCI bus order.
# NOTE(review): these variables are only honoured if set before CUDA is
# initialised; torch is imported in an earlier cell - presumably fine because
# nothing has touched the GPU yet, but confirm on a fresh kernel.
os.environ.update(
    CUDA_DEVICE_ORDER='PCI_BUS_ID',
    CUDA_VISIBLE_DEVICES='0',
)
from tqdm.auto import tqdm
from torch import nn
import torch.nn.functional as F

In [4]:
# Marker family name -> list of marker ids evaluated in this notebook.
# Insertion order matters: it fixes the order of `known_genetic_markers` and
# therefore of every per-marker result below.
marker_family = dict(
    leaf_wax=[
        'sobic_001G269200_1_51588525',
        'sobic_001G269200_1_51588838',
        'sobic_001G269200_1_51589143',
        'sobic_001G269200_1_51589435',
    ],
    dw=[
        'sobic_006G067700_1_42805319',
        'sobic_006G067700_1_42804037',
    ],
    d_locus=[
        'sobic_006G147400_1_50898459',
        'sobic_006G147400_1_50898536',
        'sobic_006G147400_1_50898315',
        'sobic_006G147400_1_50898231',
        'sobic_006G147400_1_50898523',
        'sobic_006G147400_1_50898525',
    ],
    ma=[
        'sobic_006G057866_1_40312463',
        'sobic_006G004400_2_2697734',
    ],
    tan=[
        'sobic_009G229800_1_57040680',
    ],
)

In [5]:
# Flatten every family's marker list into one list, keeping family order.
known_genetic_markers = [
    marker
    for family_markers in marker_family.values()
    for marker in family_markers
]

In [6]:
# Long-format lookup table with one (family, gene) row per marker; merged into
# the per-marker result tables to label each marker with its family.
marker_family_df = pd.DataFrame(
    [
        (family, gene)
        for family, genes in marker_family.items()
        for gene in genes
    ],
    columns=['family', 'gene'],
)

In [7]:
def get_sampler(ds):
    """Build a class-balancing ``WeightedRandomSampler`` for ``ds``.

    Every sample is weighted by the inverse frequency of its class (normalised
    so the class weights sum to one), so each class that is present is drawn
    with equal total probability.

    Args:
        ds: Dataset exposing ``marker_df['label']`` (non-negative integer class
            ids, one per sample) and ``__len__``.

    Returns:
        A ``torch.utils.data.WeightedRandomSampler`` drawing ``len(ds)`` samples
        per epoch with replacement.
    """
    # Go through NumPy rather than ``torch.tensor(series)``: the latter walks the
    # Series with label-based ``series[i]`` lookups, which is slow and fails when
    # the frame does not carry a default 0..n-1 index.
    labels = torch.as_tensor(ds.marker_df['label'].to_numpy(), dtype=torch.long)
    class_counts = torch.bincount(labels).double()

    # Inverse class frequency. Class ids that never occur get weight 0 instead of
    # inf, which would otherwise turn every weight into NaN/0 after normalising.
    class_freq = class_counts / len(ds)
    class_weights = torch.where(
        class_counts > 0,
        1.0 / class_freq,
        torch.zeros_like(class_freq),
    )
    class_weights = class_weights / class_weights.sum()

    # One vectorised gather instead of a Python loop over every sample.
    sample_weights = class_weights[labels]

    return torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(ds),
        replacement=True,
    )

In [8]:
def single_sensor_gene_pred(gene_list, sensor, noise_var=0.00,
                            data_root='/data/shared/genetic_marker_datasets/'):
    """Train and score one linear genotype probe per marker on one sensor's embeddings.

    For each marker in ``gene_list`` a fresh ``nn.Linear(128, 2)`` is trained for
    three epochs (SGD, lr 0.01 multiplied by 0.1 after every epoch) on normalised
    embeddings drawn through the class-balancing sampler from ``get_sampler``.
    The probe is then scored on the training loader, on every test image and per
    cultivar (majority vote over that cultivar's test images), and its weights
    are saved to ``result/snp_pred/{marker}_{sensor}_pred.pt``.

    Requires a CUDA device and the notebook-level ``get_sampler``,
    ``EmbeddingGeneDataset`` and ``marker_family_df``.

    Args:
        gene_list: Iterable of marker names; each is passed to
            ``EmbeddingGeneDataset``.
        sensor: Sensor tag (this notebook uses ``'rgb'`` and ``'3d'``); selects
            the embedding file and is passed to the dataset.
        noise_var: Standard deviation (despite the name) of the zero-mean
            Gaussian noise added to the normalised embeddings during training.
            ``0`` adds no noise.
        data_root: Directory handed to ``EmbeddingGeneDataset``.

    Returns:
        ``(result_df, pred_dict)``. ``result_df`` has one row per marker that is
        also listed in ``marker_family_df`` (inner merge), with columns
        ``family, gene, num_train_samples, num_test_samples, train_acc,
        test_acc, cultivar_mode_acc``. ``pred_dict`` maps each marker to a frame
        with one row per test image and columns ``pred``, ``gt``, ``cultivar``.
    """
    gene_list = list(gene_list)  # accept any iterable; it is traversed once below

    ebd_result = torch.load(f'results/snp_pred/s9_jpg_{sensor}_gene_ds_img_ebd.pth')
    all_ebd = ebd_result['ebd_128'].cuda()
    device = all_ebd.device

    def _lookup(ebd_index):
        """Gather the embeddings for a batch of indices and normalise them (F.normalize defaults)."""
        # Indices are always moved to the embedding's device so the train and
        # test paths index the table the same way.
        return F.normalize(all_ebd[ebd_index.to(device)])

    def _predict(model, loader):
        """Run ``model`` over ``loader``; return (logits, labels, cultivars), tensors on the CPU."""
        logits, labels, cultivars = [], [], []
        with torch.no_grad():
            for ebd_index, label, cultivar in tqdm(loader, leave=False):
                logits.append(model(_lookup(ebd_index)).cpu())
                labels.append(label.cpu())
                cultivars.extend(cultivar)
        return torch.cat(logits), torch.cat(labels), cultivars

    def _first_mode(values):
        """Most frequent value; ties resolve to the smallest value (Series.mode is sorted)."""
        return values.mode().iloc[0]

    result_columns = ['gene', 'num_train_samples', 'num_test_samples',
                      'train_acc', 'test_acc', 'cultivar_mode_acc']
    result_rows = []
    pred_dict = {}

    for marker in gene_list:
        ebd_ds = EmbeddingGeneDataset(data_root, marker, sensor, train=True)
        ebd_dl = torch.utils.data.DataLoader(ebd_ds, batch_size=256, sampler=get_sampler(ebd_ds))
        ebd_test_ds = EmbeddingGeneDataset(data_root, marker, sensor, train=False)
        ebd_test_dl = torch.utils.data.DataLoader(ebd_test_ds, batch_size=256, shuffle=False)
        print(f'gene: {marker}')

        # Kept inside nn.Sequential so the saved state_dict keys ('0.weight',
        # '0.bias') stay compatible with previously written checkpoints.
        gene_pred_fc = nn.Sequential(nn.Linear(128, 2)).to(device)
        loss_func = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(gene_pred_fc.parameters(), lr=0.01, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

        # --- train ---
        gene_pred_fc.train()
        for _epoch in range(3):
            tbar = tqdm(ebd_dl, leave=False)
            running_loss = 0.0
            for step, (ebd_index, label, _cultivar) in enumerate(tbar, start=1):
                optimizer.zero_grad()
                ebd = _lookup(ebd_index)
                if noise_var:
                    # Sample directly on the embedding's device; skipped when 0.
                    ebd = ebd + noise_var * torch.randn_like(ebd)
                loss = loss_func(gene_pred_fc(ebd), label.to(device))
                running_loss += loss.item()
                loss.backward()
                optimizer.step()
                # Report the mean over the last 20 steps only once 20 losses
                # have actually been accumulated.
                if step % 20 == 0:
                    tbar.set_description(f'loss: {running_loss / 20:.4f}')
                    running_loss = 0.0
            lr_scheduler.step()

        # --- evaluate ---
        gene_pred_fc.eval()
        # NOTE: ebd_dl draws through the weighted sampler, so this is accuracy on
        # a class-balanced resample (with replacement) of the training set.
        train_logits, train_labels, _ = _predict(gene_pred_fc, ebd_dl)
        train_acc = (train_logits.argmax(dim=1) == train_labels).float().mean().item()

        test_logits, test_labels, test_cultivars = _predict(gene_pred_fc, ebd_test_dl)
        test_pred = test_logits.argmax(dim=1)
        acc = (test_pred == test_labels).float().mean().item()

        pred_df = pd.DataFrame({'pred': test_pred.numpy(),
                                'gt': test_labels.numpy(),
                                'cultivar': test_cultivars})
        # Majority vote per cultivar. A plain `agg(pd.Series.mode)` yields an
        # array on ties, which breaks the equality check below.
        cultivar_mode_df = pred_df.groupby('cultivar')[['pred', 'gt']].agg(_first_mode)
        cultivar_mode_acc = (cultivar_mode_df['pred'] == cultivar_mode_df['gt']).mean()

        pred_dict[marker] = pred_df
        result_rows.append({'gene': marker,
                            'num_train_samples': len(ebd_ds),
                            'num_test_samples': len(ebd_test_ds),
                            'train_acc': train_acc,
                            'test_acc': acc,
                            'cultivar_mode_acc': cultivar_mode_acc})

        # NOTE(review): weights go to 'result/...' while the embeddings are read
        # from 'results/...'; path kept as-is - confirm which directory is intended.
        save_path = f'result/snp_pred/{marker}_{sensor}_pred.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(gene_pred_fc.state_dict(), save_path)

        print(f'gene: {marker}, train_imgs: {len(ebd_ds)}, test_imgs: {len(ebd_test_ds)}, '
              f'train_acc:{train_acc:.4f}, acc: {acc:.4f}, cultivar_acc: {cultivar_mode_acc:.4f}')

    result_df = pd.DataFrame(result_rows, columns=result_columns)
    # Inner merge: markers missing from marker_family_df are dropped from the table.
    result_df = result_df.merge(marker_family_df, on='gene')
    result_df = result_df[['family', 'gene', 'num_train_samples', 'num_test_samples',
                           'train_acc', 'test_acc', 'cultivar_mode_acc']]
    return result_df, pred_dict

In [9]:
# Train and evaluate one probe per known marker on the RGB embeddings.
rgb_result_df, rgb_pred_dict = single_sensor_gene_pred(
    gene_list=known_genetic_markers,
    sensor='rgb',
)

gene: sobic_001G269200_1_51588525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1091.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1091.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1091.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1091.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=323.0), HTML(value='')))

gene: sobic_001G269200_1_51588525, train_imgs: 279151, test_imgs: 82644, train_acc:0.6532,               acc: 0.6146, cultivar_acc: 0.6400
gene: sobic_001G269200_1_51588838


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1005.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1005.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1005.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1005.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=290.0), HTML(value='')))

gene: sobic_001G269200_1_51588838, train_imgs: 257180, test_imgs: 74100, train_acc:0.6336,               acc: 0.6015, cultivar_acc: 0.6522
gene: sobic_001G269200_1_51589143


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1049.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1049.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1049.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1049.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=437.0), HTML(value='')))

gene: sobic_001G269200_1_51589143, train_imgs: 268544, test_imgs: 111672, train_acc:0.6376,               acc: 0.6219, cultivar_acc: 0.6618
gene: sobic_001G269200_1_51589435


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1247.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1247.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1247.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1247.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=210.0), HTML(value='')))

gene: sobic_001G269200_1_51589435, train_imgs: 319136, test_imgs: 53620, train_acc:0.6505,               acc: 0.6399, cultivar_acc: 0.7812
gene: sobic_006G067700_1_42805319


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1217.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1217.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1217.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1217.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=488.0), HTML(value='')))

gene: sobic_006G067700_1_42805319, train_imgs: 311427, test_imgs: 124734, train_acc:0.6220,               acc: 0.6045, cultivar_acc: 0.6538
gene: sobic_006G067700_1_42804037


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1166.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1166.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1166.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1166.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=461.0), HTML(value='')))

gene: sobic_006G067700_1_42804037, train_imgs: 298316, test_imgs: 117852, train_acc:0.6207,               acc: 0.5970, cultivar_acc: 0.6528
gene: sobic_006G147400_1_50898459


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1252.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1252.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1252.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1252.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=228.0), HTML(value='')))

gene: sobic_006G147400_1_50898459, train_imgs: 320272, test_imgs: 58340, train_acc:0.6091,               acc: 0.5388, cultivar_acc: 0.5526
gene: sobic_006G147400_1_50898536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1218.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1218.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1218.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1218.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=246.0), HTML(value='')))

gene: sobic_006G147400_1_50898536, train_imgs: 311582, test_imgs: 62912, train_acc:0.6006,               acc: 0.5050, cultivar_acc: 0.5000
gene: sobic_006G147400_1_50898315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1075.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1075.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1075.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1075.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=249.0), HTML(value='')))

gene: sobic_006G147400_1_50898315, train_imgs: 275004, test_imgs: 63664, train_acc:0.5963,               acc: 0.5272, cultivar_acc: 0.5526
gene: sobic_006G147400_1_50898231


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=929.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=929.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=929.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=929.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=209.0), HTML(value='')))

gene: sobic_006G147400_1_50898231, train_imgs: 237704, test_imgs: 53468, train_acc:0.5891,               acc: 0.6171, cultivar_acc: 0.7353
gene: sobic_006G147400_1_50898523


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1216.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1216.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1216.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1216.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=235.0), HTML(value='')))

gene: sobic_006G147400_1_50898523, train_imgs: 311172, test_imgs: 59966, train_acc:0.5881,               acc: 0.5776, cultivar_acc: 0.5526
gene: sobic_006G147400_1_50898525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1213.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1213.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1213.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1213.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=240.0), HTML(value='')))

gene: sobic_006G147400_1_50898525, train_imgs: 310297, test_imgs: 61254, train_acc:0.5734,               acc: 0.6252, cultivar_acc: 0.7632
gene: sobic_006G057866_1_40312463


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1656.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1656.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1656.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1656.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=132.0), HTML(value='')))

gene: sobic_006G057866_1_40312463, train_imgs: 423910, test_imgs: 33706, train_acc:0.6591,               acc: 0.6704, cultivar_acc: 0.8500
gene: sobic_006G004400_2_2697734


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=805.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=805.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=805.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=805.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=237.0), HTML(value='')))

gene: sobic_006G004400_2_2697734, train_imgs: 205872, test_imgs: 60636, train_acc:0.6412,               acc: 0.6235, cultivar_acc: 0.7632
gene: sobic_009G229800_1_57040680


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1445.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1445.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1445.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1445.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=353.0), HTML(value='')))

gene: sobic_009G229800_1_57040680, train_imgs: 369672, test_imgs: 90132, train_acc:0.6300,               acc: 0.5868, cultivar_acc: 0.6111


In [10]:
# Same per-marker probes, this time on the 3D-sensor embeddings.
d3_result_df, d3_pred_dict = single_sensor_gene_pred(
    gene_list=known_genetic_markers,
    sensor='3d',
)

gene: sobic_001G269200_1_51588525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1115.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1115.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1115.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1115.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=316.0), HTML(value='')))

gene: sobic_001G269200_1_51588525, train_imgs: 285259, test_imgs: 80892, train_acc:0.6019,               acc: 0.5687, cultivar_acc: 0.7000
gene: sobic_001G269200_1_51588838


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1024.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1024.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1024.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1024.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=295.0), HTML(value='')))

gene: sobic_001G269200_1_51588838, train_imgs: 262127, test_imgs: 75394, train_acc:0.5940,               acc: 0.5642, cultivar_acc: 0.6957
gene: sobic_001G269200_1_51589143


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1067.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1067.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1067.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1067.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=440.0), HTML(value='')))

gene: sobic_001G269200_1_51589143, train_imgs: 273089, test_imgs: 112610, train_acc:0.6021,               acc: 0.5971, cultivar_acc: 0.6618
gene: sobic_001G269200_1_51589435


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1262.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1262.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1262.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1262.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=217.0), HTML(value='')))

gene: sobic_001G269200_1_51589435, train_imgs: 322863, test_imgs: 55332, train_acc:0.6261,               acc: 0.6064, cultivar_acc: 0.7500
gene: sobic_006G067700_1_42805319


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1224.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1224.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1224.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1224.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))

gene: sobic_006G067700_1_42805319, train_imgs: 313286, test_imgs: 127898, train_acc:0.5924,               acc: 0.5842, cultivar_acc: 0.6538
gene: sobic_006G067700_1_42804037


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1181.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1181.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1181.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1181.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=471.0), HTML(value='')))

gene: sobic_006G067700_1_42804037, train_imgs: 302262, test_imgs: 120344, train_acc:0.5934,               acc: 0.5708, cultivar_acc: 0.6806
gene: sobic_006G147400_1_50898459


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1261.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1261.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1261.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1261.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=231.0), HTML(value='')))

gene: sobic_006G147400_1_50898459, train_imgs: 322659, test_imgs: 59056, train_acc:0.5920,               acc: 0.5362, cultivar_acc: 0.6053
gene: sobic_006G147400_1_50898536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=256.0), HTML(value='')))

gene: sobic_006G147400_1_50898536, train_imgs: 314192, test_imgs: 65414, train_acc:0.5774,               acc: 0.5689, cultivar_acc: 0.6842
gene: sobic_006G147400_1_50898315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1085.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1085.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1085.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1085.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=254.0), HTML(value='')))

gene: sobic_006G147400_1_50898315, train_imgs: 277623, test_imgs: 64800, train_acc:0.5777,               acc: 0.5837, cultivar_acc: 0.6579
gene: sobic_006G147400_1_50898231


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=930.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=930.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=930.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=930.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=215.0), HTML(value='')))

gene: sobic_006G147400_1_50898231, train_imgs: 237907, test_imgs: 54942, train_acc:0.5717,               acc: 0.5779, cultivar_acc: 0.6176
gene: sobic_006G147400_1_50898523


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1220.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1220.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1220.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1220.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=238.0), HTML(value='')))

gene: sobic_006G147400_1_50898523, train_imgs: 312081, test_imgs: 60792, train_acc:0.5848,               acc: 0.5676, cultivar_acc: 0.5789
gene: sobic_006G147400_1_50898525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=237.0), HTML(value='')))

gene: sobic_006G147400_1_50898525, train_imgs: 314401, test_imgs: 60658, train_acc:0.5923,               acc: 0.5756, cultivar_acc: 0.6579
gene: sobic_006G057866_1_40312463


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1673.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1673.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1673.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1673.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=137.0), HTML(value='')))

gene: sobic_006G057866_1_40312463, train_imgs: 428177, test_imgs: 34958, train_acc:0.6198,               acc: 0.5432, cultivar_acc: 0.6500
gene: sobic_006G004400_2_2697734


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=806.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=806.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=806.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=806.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=247.0), HTML(value='')))

gene: sobic_006G004400_2_2697734, train_imgs: 206256, test_imgs: 63080, train_acc:0.6450,               acc: 0.6304, cultivar_acc: 0.8158
gene: sobic_009G229800_1_57040680


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1467.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1467.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1467.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1467.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=354.0), HTML(value='')))

gene: sobic_009G229800_1_57040680, train_imgs: 375433, test_imgs: 90462, train_acc:0.5933,               acc: 0.5899, cultivar_acc: 0.7037


In [11]:
# Render floats in displayed frames with thousands separators and 3 decimals.
pd.set_option('display.float_format', '{:,.3f}'.format)

In [12]:
# Per-marker summary of the RGB probes (sample counts, train/test accuracy, cultivar-level accuracy).
rgb_result_df

Unnamed: 0,family,gene,num_train_samples,num_test_samples,train_acc,test_acc,cultivar_mode_acc
0,leaf_wax,sobic_001G269200_1_51588525,279151,82644,0.653,0.615,0.64
1,leaf_wax,sobic_001G269200_1_51588838,257180,74100,0.634,0.601,0.652
2,leaf_wax,sobic_001G269200_1_51589143,268544,111672,0.638,0.622,0.662
3,leaf_wax,sobic_001G269200_1_51589435,319136,53620,0.65,0.64,0.781
4,dw,sobic_006G067700_1_42805319,311427,124734,0.622,0.604,0.654
5,dw,sobic_006G067700_1_42804037,298316,117852,0.621,0.597,0.653
6,d_locus,sobic_006G147400_1_50898459,320272,58340,0.609,0.539,0.553
7,d_locus,sobic_006G147400_1_50898536,311582,62912,0.601,0.505,0.5
8,d_locus,sobic_006G147400_1_50898315,275004,63664,0.596,0.527,0.553
9,d_locus,sobic_006G147400_1_50898231,237704,53468,0.589,0.617,0.735


In [13]:
# Per-marker summary of the 3D-sensor probes (same columns as the RGB table).
d3_result_df

Unnamed: 0,family,gene,num_train_samples,num_test_samples,train_acc,test_acc,cultivar_mode_acc
0,leaf_wax,sobic_001G269200_1_51588525,285259,80892,0.602,0.569,0.7
1,leaf_wax,sobic_001G269200_1_51588838,262127,75394,0.594,0.564,0.696
2,leaf_wax,sobic_001G269200_1_51589143,273089,112610,0.602,0.597,0.662
3,leaf_wax,sobic_001G269200_1_51589435,322863,55332,0.626,0.606,0.75
4,dw,sobic_006G067700_1_42805319,313286,127898,0.592,0.584,0.654
5,dw,sobic_006G067700_1_42804037,302262,120344,0.593,0.571,0.681
6,d_locus,sobic_006G147400_1_50898459,322659,59056,0.592,0.536,0.605
7,d_locus,sobic_006G147400_1_50898536,314192,65414,0.577,0.569,0.684
8,d_locus,sobic_006G147400_1_50898315,277623,64800,0.578,0.584,0.658
9,d_locus,sobic_006G147400_1_50898231,237907,54942,0.572,0.578,0.618


In [14]:
def multimodal_mode(rgb_df, d3_df):
    """Cultivar-level accuracy after pooling per-image predictions from two sensors.

    The two prediction frames are stacked, and for every cultivar the most
    frequent ``pred`` is compared with the most frequent ``gt``.

    Args:
        rgb_df: Frame with ``pred``, ``gt`` and ``cultivar`` columns (one row per
            image) from the RGB probe.
        d3_df: Same layout, from the 3D probe.

    Returns:
        Fraction of cultivars whose pooled majority prediction equals the
        majority ground-truth label.
    """
    def _first_mode(values):
        # Series.mode() returns every tied value in sorted order; taking the first
        # keeps the aggregate a scalar. A bare `agg(pd.Series.mode)` produces an
        # array on ties (likely when pooling two sensors), which breaks the
        # equality check below.
        return values.mode().iloc[0]

    all_pred = pd.concat([rgb_df, d3_df])
    all_pred_mode = all_pred.groupby('cultivar')[['pred', 'gt']].agg(_first_mode)
    return (all_pred_mode['pred'] == all_pred_mode['gt']).mean()

In [15]:
# Late fusion: pool the RGB and 3D per-image predictions for each marker and
# print the resulting cultivar-level accuracy.
for marker, rgb_pred_df in rgb_pred_dict.items():
    fused_acc = multimodal_mode(rgb_pred_df, d3_pred_dict[marker])
    print(f'{marker} {fused_acc:.3f}')

sobic_001G269200_1_51588525 0.640
sobic_001G269200_1_51588838 0.674
sobic_001G269200_1_51589143 0.662
sobic_001G269200_1_51589435 0.750
sobic_006G067700_1_42805319 0.667
sobic_006G067700_1_42804037 0.639
sobic_006G147400_1_50898459 0.579
sobic_006G147400_1_50898536 0.579
sobic_006G147400_1_50898315 0.632
sobic_006G147400_1_50898231 0.735
sobic_006G147400_1_50898523 0.658
sobic_006G147400_1_50898525 0.763
sobic_006G057866_1_40312463 0.800
sobic_006G004400_2_2697734 0.789
sobic_009G229800_1_57040680 0.630


# Fusion

In [16]:
# Load the saved embedding bundles for both sensors - the same files that
# single_sensor_gene_pred reads for sensor='rgb' and sensor='3d'.
# NOTE(review): torch.load unpickles its input; only load files you trust.
# NOTE(review): `d3_edb_result` is spelled "edb" rather than "ebd"; left as-is
# because a later cell reads the variable under this name.
rgb_ebd_result = torch.load('results/snp_pred/s9_jpg_rgb_gene_ds_img_ebd.pth')
d3_edb_result = torch.load('results/snp_pred/s9_jpg_3d_gene_ds_img_ebd.pth')

In [18]:
# Pull the 'ebd_128' entry out of each bundle; these are the tensors the fusion
# probe indexes per batch and concatenates (RGB first, then 3D).
rgb_ebd = rgb_ebd_result['ebd_128']
d3_ebd = d3_edb_result['ebd_128']

In [21]:
# Re-import genetic_marker_dataset so that edits made to the module after the
# first import are picked up without restarting the kernel, then rebind
# EmbeddingGeneDataset to the reloaded class.
# NOTE(review): presumably a development leftover (the fusion cell below uses the
# 'multimodal' dataset mode); on a fresh kernel the top-level import already
# suffices, and `%autoreload 2` would make this cell unnecessary.
import importlib
import genetic_marker_dataset
importlib.reload(genetic_marker_dataset)
from genetic_marker_dataset import EmbeddingGeneDataset

In [29]:
def _first_mode(values):
    """Most frequent value of a Series; ties resolve to the smallest value."""
    # Series.mode() returns every tied value, which would put an array in the
    # aggregated cell; picking the first keeps the result scalar.
    return values.mode().iloc[0]


def _collect_predictions(model, loader, rgb_table, d3_table):
    """Run ``model`` over ``loader`` without tracking gradients.

    Every batch is unpacked as ``(rgb_ebd_index, d3_ebd_index, label, cultivar)``.
    The two index tensors select rows of ``rgb_table`` / ``d3_table``; the
    selected rows are concatenated along dim 1 and fed to ``model``.

    Returns:
        (logits, labels, cultivars): logits and labels moved to CPU and
        concatenated over all batches, plus a flat list with one cultivar
        entry per sample, in loader order.
    """
    logits, labels, cultivars = [], [], []
    with torch.no_grad():
        for rgb_ebd_index, d3_ebd_index, label, cultivar in tqdm(loader, leave=False):
            fused = torch.cat([rgb_table[rgb_ebd_index.cuda()],
                               d3_table[d3_ebd_index.cuda()]], dim=1)
            logits.append(model(fused).cpu())
            labels.append(label.cpu())
            cultivars.extend(cultivar)
    return torch.cat(logits), torch.cat(labels), cultivars


# Per-marker bookkeeping, consumed when the results table is built.
num_train_samples_list = []
num_test_samples_list = []
train_acc_list = []
test_acc_list = []
test_mode_acc_list = []
pred_dict = {}
# Keep both embedding tables on the GPU so batches can be gathered by index.
rgb_ebd = rgb_ebd.cuda()
d3_ebd = d3_ebd.cuda()


# Train and evaluate one linear classifier per known genetic marker on the
# concatenated RGB + 3D embeddings.
for markers in known_genetic_markers:
    ebd_ds = EmbeddingGeneDataset('/data/shared/genetic_marker_datasets/', markers, 'multimodal', train=True)
    # get_sampler draws training samples with inverse-class-frequency weights.
    train_sampler = get_sampler(ebd_ds)
    ebd_dl = torch.utils.data.DataLoader(ebd_ds, batch_size=256, sampler=train_sampler)
    ebd_test_ds = EmbeddingGeneDataset('/data/shared/genetic_marker_datasets/', markers, 'multimodal', train=False)
    ebd_test_dl = torch.utils.data.DataLoader(ebd_test_ds, batch_size=256, shuffle=False)
    print(f'gene: {markers}')
    num_train_samples_list.append(len(ebd_ds))
    num_test_samples_list.append(len(ebd_test_ds))
    # Input width 256 = RGB embedding concatenated with 3D embedding
    # (presumably 128 each, given the 'ebd_128' key -- TODO confirm).
    gene_pred_fc = nn.Sequential(#nn.Dropout(0.5),
                                 nn.Linear(256, 2))
    gene_pred_fc.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(gene_pred_fc.parameters(), lr=0.01, momentum=0.9)
    # The learning rate is multiplied by 0.1 after every epoch.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # train
    gene_pred_fc.train()
    for ep in range(3):
        tbar = tqdm(ebd_dl, leave=False)
        total_loss = 0.0
        loss_batches = 0
        for i, (rgb_ebd_index, d3_ebd_index, label, cultivar) in enumerate(tbar):
            optimizer.zero_grad()
            rgb_e = rgb_ebd[rgb_ebd_index.cuda()]
            d3_e = d3_ebd[d3_ebd_index.cuda()]
            ebd = torch.cat([rgb_e, d3_e], dim=1)
            gene_pred = gene_pred_fc(ebd)
            loss = loss_func(gene_pred, label.cuda())
            total_loss += loss.item()
            loss_batches += 1
            loss.backward()
            optimizer.step()
            if i % 20 == 0:
                # Average over the batches actually accumulated since the last
                # refresh (only one batch at i == 0, twenty afterwards).
                tbar.set_description(f'loss: {total_loss / loss_batches:.4f}')
                total_loss = 0.0
                loss_batches = 0
        lr_scheduler.step()

    # train acc
    # NOTE(review): this goes through the sampler-driven loader, so the figure
    # is measured on a class-rebalanced redraw of the training set.
    gene_pred_fc.eval()
    train_preds, train_labels = _collect_predictions(gene_pred_fc, ebd_dl, rgb_ebd, d3_ebd)[:2]
    train_acc = (train_preds.argmax(dim=1) == train_labels).float().mean()
    train_acc_list.append(train_acc)

    # test
    all_preds, all_labels, all_cultivars = _collect_predictions(gene_pred_fc, ebd_test_dl, rgb_ebd, d3_ebd)
    acc = (all_preds.argmax(dim=1) == all_labels).float().mean()
    pred_df = pd.DataFrame({'pred': all_preds.argmax(dim=1).numpy(), 'gt': all_labels.numpy(), 'cultivar': all_cultivars})
    # Majority vote per cultivar; ties resolve to the smallest label.
    cultivar_mode_df = pred_df.groupby('cultivar')[['pred', 'gt']].agg(_first_mode)
    cultivar_mode_acc = (cultivar_mode_df['pred'] == cultivar_mode_df['gt']).mean()
    test_acc_list.append(acc)
    test_mode_acc_list.append(cultivar_mode_acc)
    # Key by this iteration's marker.  The previous code used `marker`, a
    # leftover variable from an earlier cell, so every gene overwrote one entry.
    pred_dict[markers] = pred_df

    # NOTE(review): the embeddings are loaded from 'results/snp_pred/' but the
    # weights are written under 'result/snp_pred/' -- confirm this is intended.
    save_path = f'result/snp_pred/{markers}_mm_pred.pt'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(gene_pred_fc.state_dict(), save_path)
    print(f'gene: {markers}, train_imgs: {len(ebd_ds)}, test_imgs: {len(ebd_test_ds)}, train_acc:{train_acc:.4f}, acc: {acc:.4f}')

gene: sobic_001G269200_1_51588525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1084.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1084.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1084.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1084.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=306.0), HTML(value='')))

gene: sobic_001G269200_1_51588525, train_imgs: 277497, test_imgs: 78208, train_acc:0.6978, acc: 0.6258
gene: sobic_001G269200_1_51588838


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=993.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=993.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=993.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=993.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=289.0), HTML(value='')))

gene: sobic_001G269200_1_51588838, train_imgs: 254095, test_imgs: 73738, train_acc:0.6958, acc: 0.5743
gene: sobic_001G269200_1_51589143


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1037.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1037.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1037.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1037.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=425.0), HTML(value='')))

gene: sobic_001G269200_1_51589143, train_imgs: 265324, test_imgs: 108760, train_acc:0.6810, acc: 0.6295
gene: sobic_001G269200_1_51589435


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1228.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=212.0), HTML(value='')))

gene: sobic_001G269200_1_51589435, train_imgs: 314150, test_imgs: 54082, train_acc:0.7040, acc: 0.6518
gene: sobic_006G067700_1_42805319


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1192.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1192.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1192.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1192.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=484.0), HTML(value='')))

gene: sobic_006G067700_1_42805319, train_imgs: 305050, test_imgs: 123898, train_acc:0.6597, acc: 0.6266
gene: sobic_006G067700_1_42804037


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1149.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1149.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1149.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1149.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=456.0), HTML(value='')))

gene: sobic_006G067700_1_42804037, train_imgs: 293891, test_imgs: 116642, train_acc:0.6667, acc: 0.6082
gene: sobic_006G147400_1_50898459


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1229.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=222.0), HTML(value='')))

gene: sobic_006G147400_1_50898459, train_imgs: 314483, test_imgs: 56728, train_acc:0.6681, acc: 0.5716
gene: sobic_006G147400_1_50898536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1193.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1193.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1193.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1193.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=251.0), HTML(value='')))

gene: sobic_006G147400_1_50898536, train_imgs: 305214, test_imgs: 64230, train_acc:0.6584, acc: 0.5832
gene: sobic_006G147400_1_50898315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1052.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1052.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1052.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1052.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=248.0), HTML(value='')))

gene: sobic_006G147400_1_50898315, train_imgs: 269254, test_imgs: 63418, train_acc:0.6550, acc: 0.5946
gene: sobic_006G147400_1_50898231


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=903.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=903.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=903.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=903.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=207.0), HTML(value='')))

gene: sobic_006G147400_1_50898231, train_imgs: 231131, test_imgs: 52970, train_acc:0.6580, acc: 0.6393
gene: sobic_006G147400_1_50898523


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1187.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1187.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1187.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1187.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=229.0), HTML(value='')))

gene: sobic_006G147400_1_50898523, train_imgs: 303725, test_imgs: 58540, train_acc:0.6580, acc: 0.6211
gene: sobic_006G147400_1_50898525


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1197.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1197.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1197.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1197.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=229.0), HTML(value='')))

gene: sobic_006G147400_1_50898525, train_imgs: 306231, test_imgs: 58378, train_acc:0.6489, acc: 0.6379
gene: sobic_006G057866_1_40312463


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1628.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1628.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1628.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1628.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=133.0), HTML(value='')))

gene: sobic_006G057866_1_40312463, train_imgs: 416604, test_imgs: 34004, train_acc:0.7347, acc: 0.6404
gene: sobic_006G004400_2_2697734


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=783.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=783.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=783.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=783.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=238.0), HTML(value='')))

gene: sobic_006G004400_2_2697734, train_imgs: 200431, test_imgs: 60686, train_acc:0.7496, acc: 0.6956
gene: sobic_009G229800_1_57040680


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1428.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1428.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1428.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1428.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=342.0), HTML(value='')))

gene: sobic_009G229800_1_57040680, train_imgs: 365353, test_imgs: 87452, train_acc:0.6807, acc: 0.6299


In [39]:
# One row per marker, in the order of known_genetic_markers.  The per-image
# accuracies were stored as 0-d tensors, so they are unwrapped to Python
# floats; the cultivar-level mode accuracy is used as stored.
result_df = pd.DataFrame(
    {
        'gene': known_genetic_markers,
        'num_train_samples': num_train_samples_list,
        'num_test_samples': num_test_samples_list,
        'train_acc': [acc_tensor.item() for acc_tensor in train_acc_list],
        'test_acc': [acc_tensor.item() for acc_tensor in test_acc_list],
        'mode_acc': test_mode_acc_list,
    }
)

In [40]:
# Attach the marker family to every gene row.  The guard keeps this cell
# idempotent: merging a second time would produce 'family_x' / 'family_y'
# columns and break any later selection of the 'family' column.
if 'family' not in result_df.columns:
    result_df = result_df.merge(marker_family_df, on='gene')

In [42]:
# Show floats with three decimals, then display the fused-model summary with
# the family column first.
pd.set_option('display.float_format', '{:,.3f}'.format)
multimodal_result_df = result_df.loc[
    :,
    ['family', 'gene', 'num_train_samples', 'num_test_samples', 'train_acc', 'test_acc', 'mode_acc'],
]
multimodal_result_df

Unnamed: 0,family,gene,num_train_samples,num_test_samples,train_acc,test_acc,mode_acc
0,leaf_wax,sobic_001G269200_1_51588525,277497,78208,0.698,0.626,0.7
1,leaf_wax,sobic_001G269200_1_51588838,254095,73738,0.696,0.574,0.587
2,leaf_wax,sobic_001G269200_1_51589143,265324,108760,0.681,0.63,0.676
3,leaf_wax,sobic_001G269200_1_51589435,314150,54082,0.704,0.652,0.75
4,dw,sobic_006G067700_1_42805319,305050,123898,0.66,0.627,0.756
5,dw,sobic_006G067700_1_42804037,293891,116642,0.667,0.608,0.681
6,d_locus,sobic_006G147400_1_50898459,314483,56728,0.668,0.572,0.632
7,d_locus,sobic_006G147400_1_50898536,305214,64230,0.658,0.583,0.658
8,d_locus,sobic_006G147400_1_50898315,269254,63418,0.655,0.595,0.711
9,d_locus,sobic_006G147400_1_50898231,231131,52970,0.658,0.639,0.765
