In [1]:
# Move one directory up, presumably to the project root so that the local
# `genetic_marker_dataset` module and the relative `results/` paths resolve.
# NOTE(review): not idempotent — re-running this cell climbs another level.
# Bare `cd` relies on IPython automagic being enabled.
cd ../

In [2]:
import torch, os
import pandas as pd
from genetic_marker_dataset import EmbeddingGeneDataset

In [3]:
# Number GPUs in PCI bus order and expose only device 0 to this process.
# Must take effect before the first CUDA call made by any later `.cuda()`.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='0')
from tqdm.auto import tqdm
from torch import nn
import torch.nn.functional as F

In [4]:
# Marker ids grouped by family label. The keys become the `family` column and
# the ids the `gene` column of marker_family_df; insertion order is preserved.
marker_family = dict(
    leaf_wax=[
        'sobic_001G269200_1_51588525',
        'sobic_001G269200_1_51588838',
        'sobic_001G269200_1_51589143',
        'sobic_001G269200_1_51589435',
    ],
    dw=[
        'sobic_006G067700_1_42805319',
        'sobic_006G067700_1_42804037',
    ],
    d_locus=[
        'sobic_006G147400_1_50898459',
        'sobic_006G147400_1_50898536',
        'sobic_006G147400_1_50898315',
        'sobic_006G147400_1_50898231',
        'sobic_006G147400_1_50898523',
        'sobic_006G147400_1_50898525',
    ],
    ma=[
        'sobic_006G057866_1_40312463',
        'sobic_006G004400_2_2697734',
    ],
    tan=[
        'sobic_009G229800_1_57040680',
    ],
)

In [5]:
# Every marker id flattened into one list: family order first, then within-family order.
known_genetic_markers = [marker_id for family_markers in marker_family.values() for marker_id in family_markers]

In [6]:
# Long-format lookup table with one (family, gene) row per marker, used to
# attach family labels to the per-marker result tables.
marker_family_df = pd.DataFrame(
    [
        (family_name, marker_id)
        for family_name, family_markers in marker_family.items()
        for marker_id in family_markers
    ],
    columns=['family', 'gene'],
)

In [7]:
def get_sampler(ds):
    """Build a class-balanced sampler for a labelled marker dataset.

    Each sample is weighted by the inverse frequency of its label, so every
    label present in ``ds`` has the same expected share of draws. Samples are
    drawn with replacement, ``len(ds)`` of them per pass.

    Args:
        ds: dataset exposing ``marker_df`` (a DataFrame with an integer
            ``'label'`` column, one row per sample) and ``__len__``.

    Returns:
        torch.utils.data.WeightedRandomSampler over ``range(len(ds))``.

    Raises:
        ValueError: if ``ds`` has no labelled samples.
    """
    # .to_numpy() reads values positionally, independent of marker_df's index;
    # the explicit dtype also handles an empty (object-dtype) column.
    labels = torch.as_tensor(ds.marker_df['label'].to_numpy(dtype='int64'), dtype=torch.long)
    if labels.numel() == 0:
        raise ValueError('get_sampler: dataset has no labelled samples')

    class_counts = torch.bincount(labels).double()
    # Gather per-sample inverse frequencies in one vectorised op. Label values
    # that never occur have a zero count but are never gathered, so no inf/NaN
    # can leak into the weights.
    sample_weights = 1.0 / class_counts[labels]
    # The sampler only uses relative weights; normalising is cosmetic.
    sample_weights = sample_weights / sample_weights.sum()

    return torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(ds),
        replacement=True,
    )

In [8]:
def single_sensor_gene_pred(gene_list, sensor, noise_var=0.00,
                            data_root='/data/shared/genetic_marker_datasets/',
                            save_dir='result/snp_pred', epochs=3, batch_size=256):
    """Train and evaluate one linear probe per marker on a single sensor's embeddings.

    For every marker in ``gene_list``, a 2-class linear head is trained on the
    L2-normalised precomputed image embeddings (``ebd_128``) using a
    class-balanced sampler. It is then evaluated on the held-out split per
    image and per cultivar (majority vote over that cultivar's images).

    Args:
        gene_list: marker ids. Each must also appear in the global
            ``marker_family_df``, otherwise the final inner merge drops it.
        sensor: sensor tag used in the embedding file name and passed to
            ``EmbeddingGeneDataset`` (e.g. ``'rgb'``, ``'3d'``).
        noise_var: standard deviation (despite the name) of Gaussian noise
            added to the normalised training embeddings; 0 disables it.
        data_root: root directory handed to ``EmbeddingGeneDataset``.
        save_dir: directory where each trained head's ``state_dict`` is saved.
            NOTE(review): defaults to ``result/`` while embeddings are read
            from ``results/``; kept for compatibility, confirm intent.
        epochs: training epochs. StepLR divides the LR by 10 after each one.
        batch_size: batch size for both the train and test loaders.

    Returns:
        ``(result_df, pred_dict)``:
        - ``result_df``: per-marker table with columns family, gene,
          num_train_samples, num_test_samples, train_acc, test_acc,
          cultivar_mode_acc.
        - ``pred_dict``: marker -> DataFrame of per-image test predictions
          (pred, gt, cultivar).
    """
    ebd_result = torch.load(f'results/snp_pred/s9_jpg_{sensor}_gene_ds_img_ebd.pth')
    all_ebd = ebd_result['ebd_128'].cuda()
    os.makedirs(save_dir, exist_ok=True)

    def embed(ebd_index):
        # Gather a batch of precomputed embeddings by row index and L2-normalise them.
        return F.normalize(all_ebd[ebd_index.cuda()])

    def majority(values):
        # Tie-safe mode: Series.mode() is sorted, so a tie resolves to the
        # smallest label instead of yielding an array that breaks `==` below.
        return values.mode().iloc[0]

    def predict(model, loader):
        # Run `model` over `loader` without gradients; return CPU logits,
        # labels and a flat list of cultivars.
        logits, labels, cultivars = [], [], []
        with torch.no_grad():
            for ebd_index, label, cultivar in tqdm(loader, leave=False):
                logits.append(model(embed(ebd_index)).cpu())
                labels.append(label.cpu())
                cultivars.extend(cultivar)
        return torch.cat(logits), torch.cat(labels), cultivars

    summary_rows = []
    pred_dict = {}
    for marker in gene_list:
        train_ds = EmbeddingGeneDataset(data_root, marker, sensor, train=True)
        train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, sampler=get_sampler(train_ds))
        test_ds = EmbeddingGeneDataset(data_root, marker, sensor, train=False)
        test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)
        print(f'gene: {marker}')

        # Keep the nn.Sequential wrapper so saved state_dict keys ('0.weight',
        # '0.bias') stay compatible with previously written checkpoints.
        gene_pred_fc = nn.Sequential(nn.Linear(all_ebd.shape[1], 2)).cuda()
        loss_func = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(gene_pred_fc.parameters(), lr=0.01, momentum=0.9)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

        # train
        gene_pred_fc.train()
        for _ in range(epochs):
            tbar = tqdm(train_dl, leave=False)
            running_loss, running_batches = 0.0, 0
            for i, (ebd_index, label, _cultivar) in enumerate(tbar):
                optimizer.zero_grad()
                ebd = embed(ebd_index)
                if noise_var > 0:
                    # Same distribution as torch.normal(0, noise_var), drawn on-device.
                    ebd = ebd + torch.randn_like(ebd) * noise_var
                loss = loss_func(gene_pred_fc(ebd), label.cuda())
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                running_batches += 1
                if i % 20 == 0:
                    # Average over the batches actually accumulated since the last report.
                    tbar.set_description(f'loss: {running_loss / running_batches:.4f}')
                    running_loss, running_batches = 0.0, 0
            lr_scheduler.step()

        gene_pred_fc.eval()
        # Train accuracy is measured on a fresh draw from the class-balanced
        # sampler (with replacement), not on the raw training split.
        train_logits, train_labels, _ = predict(gene_pred_fc, train_dl)
        train_acc = (train_logits.argmax(dim=1) == train_labels).float().mean()

        test_logits, test_labels, test_cultivars = predict(gene_pred_fc, test_dl)
        test_pred = test_logits.argmax(dim=1)
        acc = (test_pred == test_labels).float().mean()
        pred_df = pd.DataFrame({'pred': test_pred.numpy(), 'gt': test_labels.numpy(), 'cultivar': test_cultivars})
        cultivar_mode_df = pred_df.groupby('cultivar').agg(majority)
        cultivar_mode_acc = (cultivar_mode_df['pred'] == cultivar_mode_df['gt']).mean()
        pred_dict[marker] = pred_df

        torch.save(gene_pred_fc.state_dict(), os.path.join(save_dir, f'{marker}_{sensor}_pred.pt'))
        print(f'gene: {marker}, train_imgs: {len(train_ds)}, test_imgs: {len(test_ds)}, '
              f'train_acc:{train_acc:.4f}, acc: {acc:.4f}, cultivar_acc: {cultivar_mode_acc:.4f}')
        summary_rows.append({
            'gene': marker,
            'num_train_samples': len(train_ds),
            'num_test_samples': len(test_ds),
            'train_acc': train_acc.item(),
            'test_acc': acc.item(),
            'cultivar_mode_acc': cultivar_mode_acc,
        })

    result_df = pd.DataFrame(summary_rows, columns=['gene', 'num_train_samples', 'num_test_samples',
                                                    'train_acc', 'test_acc', 'cultivar_mode_acc'])
    result_df = result_df.merge(marker_family_df, on='gene')
    result_df = result_df[['family', 'gene', 'num_train_samples', 'num_test_samples', 'train_acc', 'test_acc', 'cultivar_mode_acc']]
    return result_df, pred_dict

In [None]:
# Per-marker linear probes on the RGB embeddings.
rgb_result_df, rgb_pred_dict = single_sensor_gene_pred(gene_list=known_genetic_markers, sensor='rgb')

In [None]:
# Per-marker linear probes on the 3D embeddings.
d3_result_df, d3_pred_dict = single_sensor_gene_pred(gene_list=known_genetic_markers, sensor='3d')

In [11]:
# Show floats with 3 decimals and thousands separators in displayed tables.
pd.set_option('display.float_format', '{:,.3f}'.format)

In [12]:
# Per-marker summary of the RGB-only probes.
rgb_result_df

Unnamed: 0,family,gene,num_train_samples,num_test_samples,train_acc,test_acc,cultivar_mode_acc
0,leaf_wax,sobic_001G269200_1_51588525,279151,82644,0.653,0.615,0.64
1,leaf_wax,sobic_001G269200_1_51588838,257180,74100,0.634,0.601,0.652
2,leaf_wax,sobic_001G269200_1_51589143,268544,111672,0.638,0.622,0.662
3,leaf_wax,sobic_001G269200_1_51589435,319136,53620,0.65,0.64,0.781
4,dw,sobic_006G067700_1_42805319,311427,124734,0.622,0.604,0.654
5,dw,sobic_006G067700_1_42804037,298316,117852,0.621,0.597,0.653
6,d_locus,sobic_006G147400_1_50898459,320272,58340,0.609,0.539,0.553
7,d_locus,sobic_006G147400_1_50898536,311582,62912,0.601,0.505,0.5
8,d_locus,sobic_006G147400_1_50898315,275004,63664,0.596,0.527,0.553
9,d_locus,sobic_006G147400_1_50898231,237704,53468,0.589,0.617,0.735


In [13]:
# Per-marker summary of the 3D-only probes.
d3_result_df

Unnamed: 0,family,gene,num_train_samples,num_test_samples,train_acc,test_acc,cultivar_mode_acc
0,leaf_wax,sobic_001G269200_1_51588525,285259,80892,0.602,0.569,0.7
1,leaf_wax,sobic_001G269200_1_51588838,262127,75394,0.594,0.564,0.696
2,leaf_wax,sobic_001G269200_1_51589143,273089,112610,0.602,0.597,0.662
3,leaf_wax,sobic_001G269200_1_51589435,322863,55332,0.626,0.606,0.75
4,dw,sobic_006G067700_1_42805319,313286,127898,0.592,0.584,0.654
5,dw,sobic_006G067700_1_42804037,302262,120344,0.593,0.571,0.681
6,d_locus,sobic_006G147400_1_50898459,322659,59056,0.592,0.536,0.605
7,d_locus,sobic_006G147400_1_50898536,314192,65414,0.577,0.569,0.684
8,d_locus,sobic_006G147400_1_50898315,277623,64800,0.578,0.584,0.658
9,d_locus,sobic_006G147400_1_50898231,237907,54942,0.572,0.578,0.618


In [14]:
def multimodal_mode(rgb_df, d3_df):
    """Per-cultivar majority-vote accuracy, pooling RGB and 3D test predictions.

    Stacks both prediction tables (columns ``pred``, ``gt``, ``cultivar``) and
    takes the most frequent ``pred`` and ``gt`` per cultivar across both
    sensors. Returns the fraction of cultivars whose voted prediction equals
    the voted label.

    Each image is one vote, so the sensor with more test images for a cultivar
    carries more weight. Ties resolve to the smallest label (``Series.mode``
    is sorted). ``agg(pd.Series.mode)`` would instead return an array for a
    tie, which makes the ``==`` comparison raise.
    """
    pooled = pd.concat([rgb_df, d3_df], ignore_index=True)
    votes = pooled.groupby('cultivar')[['pred', 'gt']].agg(lambda s: s.mode().iloc[0])
    return (votes['pred'] == votes['gt']).mean()

In [15]:
# Pooled RGB+3D per-cultivar vote accuracy for each marker.
for marker, rgb_marker_preds in rgb_pred_dict.items():
    pooled_acc = multimodal_mode(rgb_marker_preds, d3_pred_dict[marker])
    print(f'{marker} {pooled_acc:.3f}')

sobic_001G269200_1_51588525 0.640
sobic_001G269200_1_51588838 0.674
sobic_001G269200_1_51589143 0.662
sobic_001G269200_1_51589435 0.750
sobic_006G067700_1_42805319 0.667
sobic_006G067700_1_42804037 0.639
sobic_006G147400_1_50898459 0.579
sobic_006G147400_1_50898536 0.579
sobic_006G147400_1_50898315 0.632
sobic_006G147400_1_50898231 0.735
sobic_006G147400_1_50898523 0.658
sobic_006G147400_1_50898525 0.763
sobic_006G057866_1_40312463 0.800
sobic_006G004400_2_2697734 0.789
sobic_009G229800_1_57040680 0.630


# Fusion: linear probe on concatenated RGB + 3D embeddings

In [16]:
# Reload both sensors' embedding bundles (RGB first, then 3D).
# The `edb` spelling is kept because later cells reference that name.
rgb_ebd_result, d3_edb_result = map(torch.load, (
    'results/snp_pred/s9_jpg_rgb_gene_ds_img_ebd.pth',
    'results/snp_pred/s9_jpg_3d_gene_ds_img_ebd.pth',
))

In [18]:
# Pull the 'ebd_128' tensors out of each loaded bundle.
rgb_ebd, d3_ebd = rgb_ebd_result['ebd_128'], d3_edb_result['ebd_128']

In [21]:
import importlib
import genetic_marker_dataset
# Re-execute genetic_marker_dataset.py to pick up source edits without a kernel restart.
genetic_marker_dataset = importlib.reload(genetic_marker_dataset)
from genetic_marker_dataset import EmbeddingGeneDataset

In [None]:
# Fusion probe: for every marker, train a linear head on the concatenated RGB
# and 3D embeddings indexed by the 'multimodal' EmbeddingGeneDataset. Each head
# is evaluated per image and per cultivar (majority vote). The per-marker lists
# filled here are turned into `result_df` by the summary cell.
# NOTE(review): unlike single_sensor_gene_pred, the embeddings are NOT
# L2-normalised before concatenation; confirm that difference is intended.
num_train_samples_list = []
num_test_samples_list = []
train_acc_list = []
test_acc_list = []
test_mode_acc_list = []
pred_dict = {}
rgb_ebd = rgb_ebd.cuda()
d3_ebd = d3_ebd.cuda()
# Checkpoints still go to 'result/' (singular) so existing files stay where
# they were. NOTE(review): embeddings are read from 'results/'; confirm intent.
mm_save_dir = 'result/snp_pred'
os.makedirs(mm_save_dir, exist_ok=True)


def _mm_fused_ebd(rgb_ebd_index, d3_ebd_index):
    """Gather one batch of RGB and 3D embeddings by index and concatenate them feature-wise."""
    return torch.cat([rgb_ebd[rgb_ebd_index.cuda()], d3_ebd[d3_ebd_index.cuda()]], dim=1)


def _mm_predict(model, loader):
    """Run `model` over `loader` without gradients; return CPU logits, labels and a flat cultivar list."""
    logits, labels, cultivars = [], [], []
    with torch.no_grad():
        for rgb_ebd_index, d3_ebd_index, label, cultivar in tqdm(loader, leave=False):
            logits.append(model(_mm_fused_ebd(rgb_ebd_index, d3_ebd_index)).cpu())
            labels.append(label.cpu())
            cultivars.extend(cultivar)
    return torch.cat(logits), torch.cat(labels), cultivars


for marker in known_genetic_markers:
    ebd_ds = EmbeddingGeneDataset('/data/shared/genetic_marker_datasets/', marker, 'multimodal', train=True)
    ebd_dl = torch.utils.data.DataLoader(ebd_ds, batch_size=256, sampler=get_sampler(ebd_ds))
    ebd_test_ds = EmbeddingGeneDataset('/data/shared/genetic_marker_datasets/', marker, 'multimodal', train=False)
    ebd_test_dl = torch.utils.data.DataLoader(ebd_test_ds, batch_size=256, shuffle=False)
    print(f'gene: {marker}')
    num_train_samples_list.append(len(ebd_ds))
    num_test_samples_list.append(len(ebd_test_ds))

    # Keep the nn.Sequential wrapper so saved state_dict keys ('0.weight', '0.bias') are unchanged.
    gene_pred_fc = nn.Sequential(nn.Linear(rgb_ebd.shape[1] + d3_ebd.shape[1], 2)).cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(gene_pred_fc.parameters(), lr=0.01, momentum=0.9)
    # step_size=1, gamma=0.1: the learning rate drops 10x after every epoch.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # train
    gene_pred_fc.train()
    for _ in range(3):
        tbar = tqdm(ebd_dl, leave=False)
        running_loss, running_batches = 0.0, 0
        for i, (rgb_ebd_index, d3_ebd_index, label, _cultivar) in enumerate(tbar):
            optimizer.zero_grad()
            loss = loss_func(gene_pred_fc(_mm_fused_ebd(rgb_ebd_index, d3_ebd_index)), label.cuda())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_batches += 1
            if i % 20 == 0:
                # Average over the batches actually accumulated since the last report.
                tbar.set_description(f'loss: {running_loss / running_batches:.4f}')
                running_loss, running_batches = 0.0, 0
        lr_scheduler.step()

    gene_pred_fc.eval()
    # Train accuracy is measured on a fresh draw from the class-balanced sampler (with replacement).
    train_logits, train_labels, _ = _mm_predict(gene_pred_fc, ebd_dl)
    train_acc = (train_logits.argmax(dim=1) == train_labels).float().mean()
    train_acc_list.append(train_acc)

    # test
    test_logits, test_labels, test_cultivars = _mm_predict(gene_pred_fc, ebd_test_dl)
    test_pred = test_logits.argmax(dim=1)
    acc = (test_pred == test_labels).float().mean()
    pred_df = pd.DataFrame({'pred': test_pred.numpy(), 'gt': test_labels.numpy(), 'cultivar': test_cultivars})
    # Tie-safe majority vote: Series.mode() is sorted, so a tie resolves to the
    # smallest label instead of an array that makes the == below raise.
    cultivar_mode_df = pred_df.groupby('cultivar').agg(lambda s: s.mode().iloc[0])
    cultivar_mode_acc = (cultivar_mode_df['pred'] == cultivar_mode_df['gt']).mean()
    test_acc_list.append(acc)
    test_mode_acc_list.append(cultivar_mode_acc)
    # Keyed by this loop's marker, so every marker keeps its own predictions.
    pred_dict[marker] = pred_df
    torch.save(gene_pred_fc.state_dict(), os.path.join(mm_save_dir, f'{marker}_mm_pred.pt'))
    print(f'gene: {marker}, train_imgs: {len(ebd_ds)}, test_imgs: {len(ebd_test_ds)}, train_acc:{train_acc:.4f}, acc: {acc:.4f}')

In [39]:
# Collect the fusion loop's per-marker metrics into one table, converting the
# scalar accuracy tensors to Python floats.
fusion_columns = {
    'gene': known_genetic_markers,
    'num_train_samples': num_train_samples_list,
    'num_test_samples': num_test_samples_list,
    'train_acc': [float(train_score) for train_score in train_acc_list],
    'test_acc': [float(test_score) for test_score in test_acc_list],
    'mode_acc': test_mode_acc_list,
}
result_df = pd.DataFrame(fusion_columns)

In [40]:
# Attach the family label to each marker (inner join on gene id).
result_df = pd.merge(result_df, marker_family_df, on='gene')

In [42]:
# Final fusion table: family first, then counts and accuracies.
pd.set_option('display.float_format', '{:,.3f}'.format)
display_columns = ['family', 'gene', 'num_train_samples', 'num_test_samples', 'train_acc', 'test_acc', 'mode_acc']
multimodal_result_df = result_df.loc[:, display_columns]
multimodal_result_df

Unnamed: 0,family,gene,num_train_samples,num_test_samples,train_acc,test_acc,mode_acc
0,leaf_wax,sobic_001G269200_1_51588525,277497,78208,0.698,0.626,0.7
1,leaf_wax,sobic_001G269200_1_51588838,254095,73738,0.696,0.574,0.587
2,leaf_wax,sobic_001G269200_1_51589143,265324,108760,0.681,0.63,0.676
3,leaf_wax,sobic_001G269200_1_51589435,314150,54082,0.704,0.652,0.75
4,dw,sobic_006G067700_1_42805319,305050,123898,0.66,0.627,0.756
5,dw,sobic_006G067700_1_42804037,293891,116642,0.667,0.608,0.681
6,d_locus,sobic_006G147400_1_50898459,314483,56728,0.668,0.572,0.632
7,d_locus,sobic_006G147400_1_50898536,305214,64230,0.658,0.583,0.658
8,d_locus,sobic_006G147400_1_50898315,269254,63418,0.655,0.595,0.711
9,d_locus,sobic_006G147400_1_50898231,231131,52970,0.658,0.639,0.765
