In [1]:
# Reload the autoreload extension and set it to mode 2, so edited modules are
# re-imported automatically before each cell executes
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [2]:
from fastai.basics import *
import json
from tqdm import tqdm

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.torch_files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.fasta as fasta

In [3]:
# Seed every random number generator used here (Python, NumPy, PyTorch) with the
# same value, then configure cuDNN for repeatable runs
random_seed = 25 # MAGIC!!
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(random_seed)

# cuDNN stays enabled; deterministic algorithms are requested and the
# benchmark/autotuning mode is switched off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = True

## Constants

### Data

In [4]:
# --- Location of the feather data set ----------------------------------------
BASE_DIR = Path("../../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-50-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR / PATH_DATA / DATA_SET

# config.json is read from the data set folder; 'maxw' and 'maxl' are used below
with (FEATHER_FOLDER / 'config.json').open('r') as fp:
    config = json.load(fp)

# --- Alphabet and window geometry --------------------------------------------
ALPHABET       = constants.ALPHABET
ALPHABET_VAL   = [*ALPHABET.values()]
ALPHABET_STR   = ''.join(ALPHABET_VAL)
ALPHABET_SIZE  = len(ALPHABET.keys())
WINDOW_SIZE    = int(config['maxw'])  # maxw = max window size
DIMENSIONS_OUT = int(config['maxl'])  # maxl = max label length
STRIDE         = WINDOW_SIZE          # stride is set equal to the window size

# --- Knowledge distillation --------------------------------------------------
KNOWLEGDE_DISTILLATION = True
TEACHER_OUTPUT = 'bonito-csv' # Set to name of y_teacher output
if KNOWLEGDE_DISTILLATION:
    # Distillation needs a teacher output name; only a warning is printed if it is missing
    if not TEACHER_OUTPUT:
        print('WARNING! Must provide name of teacher output when doing knowledge distillation')

In [5]:
# Metrics handed to the Learner: CTC accuracy built from ALPHABET
# (NOTE(review): meaning of the second argument, 5, is not visible here — confirm)
METRICS = [metric.ctc_accuracy(ALPHABET, 5)]
# Pre-configured SaveModelCallback factory (every='epoch', monitor='valid_loss')
SAVE_CALLBACK = partial(metric.SaveModelCallback, monitor='valid_loss', every='epoch')

### Train/Predict

In [6]:
# Training hyper-parameters
LR = 1e-3   # default learning rate
BS = 2**6   # batch size
EPOCHS = 5  # epochs per training run
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Model

In [7]:
import bonito_basic as model_file  # provides model_file.model, the model factory used in the search cell

# Length of the model's prediction for one window
# (NOTE(review): the // 3 presumably mirrors the network's downsampling — confirm)
DIMENSIONS_PREDICTION_OUT = 1 + WINDOW_SIZE // 3
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH

## Load data

In [8]:
# Read data from feather; with knowledge distillation the teacher output is
# loaded as well and passed on to the dataloader construction
_dataloader_kwargs = dict(split=.8, batch_size=BS, drop_last=DROP_LAST)
if KNOWLEGDE_DISTILLATION:
    data, teacher = f.load_training_data_with_teacher(FEATHER_FOLDER, TEACHER_OUTPUT)
    _dataloader_kwargs['teacher'] = teacher
else:
    data = f.load_training_data(FEATHER_FOLDER)
train_dl, valid_dl = prep.convert_to_dataloaders(data, **_dataloader_kwargs)

# Convert to databunch
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model training and hyper-parameter search

In [10]:
# CTC loss; passed to every KdLoss below as its label_loss term
_ctc_loss = metric.CtcLoss(WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)

# Grid of distillation losses keyed by 't=<temperature>,a=<alpha>'.
# alpha runs over 0.0, 0.1, ..., 1.0 and is derived from an integer step
# (step / 10) instead of np.arange with a float step: the latter yields values
# such as 0.30000000000000004, which end up in the keys and in the results file.
loss_funcs = {}
for t in [1,2,4,8,16,32]:
    for a in (step / 10 for step in range(11)):
        loss_funcs[f't={t},a={a}'] = metric.KdLoss(alpha=a, temperature=t, label_loss=_ctc_loss).loss()

In [11]:
# Optimizer factories by name; the name is what the search cell writes to the results file
optimizers = dict(AdamW=partial(torch.optim.AdamW, lr=LR, amsgrad=True))

In [None]:
# Scheduler set-ups by name; the search cell calls each value with the learner before fitting.
# The original line did not parse (unbalanced '[' and a trailing ':'); the call below
# assumes the intended arguments were (0, LR).
# NOTE(review): get_schedulars is not defined or imported in the visible notebook —
# it must be brought into scope before this cell can run.
schedulars = get_schedulars(0, LR)

In [12]:
## Hyper-parameter search: one training run per (model, loss, optimizer, scheduler)
## combination, one line per run written to 'hyper-output'.
## Columns: Model_name, ctc_accuracy, loss_function, optimizer, scheduler
## NOTE(review): the loss keys themselves contain a comma ('t=..,a=..'), so a strict
## CSV reader will see an extra column — confirm how this file is parsed downstream.
models = [partial(model_file.model, DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)]
# The file handle is not named `f` and the network is not named `m`: those are the
# import aliases of jkbc.utils.torch_files and jkbc.model, and rebinding (or deleting)
# them breaks re-running earlier cells.
with open('hyper-output', 'w') as results_file:
    results_file.write('Model_name, ctc_accuracy, loss_function, optimizer, scheduler\n')
    for model in models:
        for l_key, loss in loss_funcs.items():
            for o_key, optim in optimizers.items():
                for s_key, sched in schedulars.items():
                    # Build a new network and learner for every run so no configuration
                    # starts from weights already trained under a previous one
                    net, MODEL_NAME = model()
                    MODEL_NAME = f'{MODEL_NAME}-windowsize={WINDOW_SIZE}'
                    MODEL_DIR = f'weights/{MODEL_NAME}'

                    # Create learner
                    learner = Learner(databunch, net, loss_func=loss, model_dir=MODEL_DIR, metrics=METRICS, opt_func=optim)
                    sched(learner)

                    # FIT
                    learner.fit(EPOCHS, lr=LR)
                    # Index 1 of validate() is reported as ctc_accuracy (index 0 being the loss)
                    accuracy = learner.validate()[1]
                    results_file.write(f'{MODEL_NAME}, {accuracy}, {l_key}, {o_key}, {s_key}\n')
                    # Push completed rows to disk so progress is visible while the search runs
                    results_file.flush()

                    # Drop references before the next run allocates a new model
                    del learner, net

epoch,train_loss,valid_loss,ctc_accuracy,time
0,1.480871,1.361229,0.629004,00:33
1,1.35006,1.264773,0.582633,00:32
2,1.249974,1.207545,0.585939,00:32
3,1.155713,1.210925,0.614279,00:32
4,1.072807,1.260624,0.623069,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,152.23558,139.219681,0.579115,00:33
1,141.246643,132.831772,0.591673,00:33
2,134.520432,130.154053,0.611157,00:33
3,127.954872,135.115662,0.622013,00:33
4,122.72654,2831.031006,0.669897,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,301.43399,275.621582,0.577756,00:33
1,279.168915,261.706451,0.598185,00:33
2,265.725067,255.964752,0.60486,00:33
3,253.875717,256.764862,0.629299,00:33
4,244.831467,251.014832,0.643323,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,449.97345,418.433563,0.584143,00:33
1,418.639496,396.30957,0.593203,00:33
2,398.220032,384.035004,0.609397,00:33
3,379.99765,390.731628,0.62689,00:33
4,361.16449,401.165039,0.638774,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,598.668152,555.751343,0.572174,00:33
1,555.669006,525.318237,0.599253,00:33
2,530.196655,507.764801,0.611474,00:33
3,507.65274,495.667603,0.635287,00:33
4,484.759674,501.287567,0.63724,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,748.266479,705.245239,0.575646,00:33
1,695.551331,675.049927,0.60065,00:33
2,665.492432,673.070984,0.597785,00:33
3,638.449646,630.67804,0.607121,00:33
4,608.668823,615.791626,0.627896,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,900.335571,830.708435,0.577714,00:33
1,836.911499,787.256836,0.600804,00:33
2,797.445618,761.590454,0.615981,00:33
3,760.866028,756.997009,0.636908,00:33
4,723.571411,762.758484,0.635906,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1047.558716,967.45636,0.583629,00:33
1,973.752075,924.282898,0.591586,00:33
2,925.401917,895.182922,0.606412,00:33
3,881.836853,883.537109,0.618988,00:33
4,836.338196,901.297668,0.633197,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1198.093262,1103.411865,0.58306,00:33
1,1112.939697,1045.659302,0.597879,00:33
2,1059.002075,1032.511475,0.610313,00:33
3,1009.544128,1005.66864,0.628912,00:33
4,958.52179,1024.940063,0.633009,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1349.347656,1248.700562,0.583946,00:33
1,1252.066895,1175.556885,0.60013,00:33
2,1193.14563,1143.321777,0.610732,00:33
3,1137.815674,1131.731445,0.625391,00:33
4,1078.590576,1144.014404,0.640923,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1496.81958,1364.162598,0.580711,00:33
1,1387.225464,1301.146729,0.592261,00:33
2,1363.2229,1366.778198,0.589039,00:33
3,1308.814819,1239.517334,0.617868,00:33
4,1253.328369,1219.019653,0.638742,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1.573452,1.432275,0.600719,00:33
1,1.472413,1.3998,0.604691,00:33
2,1.425584,1.371968,0.598182,00:33
3,1.374554,1.291331,0.592637,00:33
4,1.30832,1.223976,0.59222,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,522.898132,494.697906,0.580066,00:33
1,488.40921,463.669617,0.596458,00:33
2,463.932892,454.518646,0.610743,00:33
3,438.718445,451.559845,0.626985,00:33
4,412.03775,480.304779,0.629297,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1044.269775,983.887451,0.583249,00:33
1,972.545593,943.851257,0.590707,00:33
2,922.946045,911.325378,0.608373,00:33
3,873.218384,904.754578,0.625385,00:33
4,824.449341,943.474426,0.617472,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1563.754395,1461.147949,0.586079,00:33
1,1459.557861,1418.636841,0.585472,00:33
2,1382.213623,1348.484375,0.616148,00:33
3,1304.482666,1361.864868,0.624642,00:33
4,1227.575439,1484.590942,0.612333,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,2088.079346,1961.12207,0.584361,00:33
1,1946.249268,1871.315552,0.592403,00:33
2,1848.095581,1791.660278,0.616181,00:33
3,1748.511719,1794.044678,0.631005,00:33
4,1643.641357,1880.29187,0.620043,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,2598.197021,2508.451904,0.581682,00:33
1,2421.142334,2310.531494,0.596969,00:33
2,2296.310303,2254.878662,0.621644,00:33
3,2173.620361,2317.392822,0.629769,00:33
4,2045.428833,2326.060791,0.630178,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,3132.187988,2915.860596,0.583069,00:33
1,2921.106934,2802.245117,0.590595,00:33
2,2768.942871,2695.767334,0.616167,00:33
3,2611.885254,2708.095947,0.618357,00:33
4,2456.826172,2920.748535,0.616557,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,3669.260742,3452.050781,0.578343,00:33
1,3424.221924,3295.638184,0.583974,00:33
2,3250.915527,3201.459961,0.600966,00:33
3,3076.188721,3216.391602,0.610462,00:33
4,2896.948975,3426.719727,0.607617,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,4150.971191,3868.551514,0.587826,00:33
1,3868.011963,3687.412842,0.600014,00:33
2,3674.621094,3567.81958,0.623641,00:33
3,3492.271973,3572.557861,0.627344,00:33
4,3286.875244,3769.571289,0.620192,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,4695.220703,4440.814453,0.584617,00:33
1,4374.753906,4318.156738,0.590682,00:33
2,4149.567871,4011.865479,0.61983,00:33
3,3922.711914,4085.568848,0.626655,00:33
4,3682.689209,4140.733398,0.62682,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,5231.064453,4918.289062,0.579906,00:33
1,4879.260742,4716.303223,0.591497,00:33
2,4628.892578,4514.47998,0.611818,00:33
3,4379.70166,4559.177734,0.626464,00:33
4,4113.796875,4729.834961,0.632334,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1.491528,1.34351,0.585599,00:33
1,1.357778,1.255316,0.574447,00:33
2,1.25888,1.190328,0.58675,00:33
3,1.158979,1.172765,0.625205,00:33
4,1.070406,1.216367,0.636954,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1468.574829,1382.263062,0.589388,00:33
1,1367.66687,1319.759521,0.584652,00:33
2,1289.060181,1279.800781,0.583359,00:33
3,1206.848389,1242.14978,0.6139,00:33
4,1121.925781,1264.499023,0.62457,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,2938.718262,2760.800537,0.589223,00:33
1,2743.444092,2709.862549,0.590965,00:33
2,2587.785889,2593.95874,0.589,00:33
3,2423.379395,2527.303711,0.598777,00:33
4,2246.006592,2605.013428,0.605605,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,4377.75,4135.435547,0.591608,00:33
1,4087.376221,3947.205811,0.587829,00:33
2,3853.723389,3835.037842,0.595211,00:33
3,3603.21582,3706.682617,0.611403,00:33
4,3357.728271,3930.6521,0.6151,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,5870.582031,5514.661621,0.582572,00:33
1,5468.174805,5234.197754,0.583697,00:33
2,5152.357422,5095.216797,0.589056,00:33
3,4815.057617,5064.251953,0.601083,00:33
4,4474.710938,5205.90332,0.613785,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,7299.858398,6868.313965,0.582255,00:33
1,6805.032715,6530.074219,0.585179,00:33
2,6403.018066,6312.873047,0.588905,00:33
3,5984.921875,6343.970703,0.592367,00:33
4,5565.316406,6360.208008,0.618067,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,8771.058594,8272.238281,0.581664,00:33
1,8161.308105,7844.947754,0.579987,00:33
2,7684.508789,7645.882812,0.590535,00:33
3,7186.156738,7576.163086,0.606519,00:33
4,6714.17627,7864.441895,0.607147,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,10284.935547,9631.291016,0.591335,00:33
1,9549.869141,9130.883789,0.58012,00:33
2,8988.064453,8773.810547,0.589082,00:33
3,8416.986328,9372.749023,0.580935,00:33
4,7915.561035,9257.705078,0.587951,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,11689.399414,11006.411133,0.586897,00:33
1,10913.168945,10481.59375,0.580452,00:33
2,10299.254883,10116.240234,0.591378,00:33
3,9631.18457,10182.418945,0.601332,00:33
4,8985.09082,10207.041016,0.601332,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,13198.545898,12506.407227,0.58677,00:33
1,12291.030273,11905.231445,0.581731,00:33
2,11577.857422,11409.576172,0.595073,00:33
3,10824.561523,11604.516602,0.601599,00:33
4,10085.955078,12170.899414,0.606733,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,14592.74707,13747.530273,0.580338,00:33
1,13629.054688,13251.955078,0.583354,00:33
2,12877.075195,13019.216797,0.587837,00:33
3,12063.639648,13119.439453,0.590412,00:33
4,11292.25293,13699.763672,0.598326,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,1.503366,1.352832,0.596706,00:33
1,1.378215,1.261687,0.580243,00:33
2,1.279774,1.181125,0.596418,00:33
3,1.178715,1.164912,0.635558,00:33
4,1.078201,1.204368,0.641531,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,2589.944092,2444.306885,0.602157,00:33
1,2406.909424,2310.576904,0.602961,00:33
2,2252.824219,2244.33374,0.583902,00:33
3,2088.483398,2257.28125,0.588537,00:33
4,1933.892944,2275.526611,0.593426,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,5160.828125,4874.745117,0.613293,00:33
1,4800.250977,4659.773438,0.612128,00:33
2,4510.489258,4508.509277,0.582697,00:33
3,4204.121094,4521.929199,0.591567,00:33
4,3898.005127,4464.377441,0.594836,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,7725.854004,7259.774414,0.59578,00:33
1,7179.766113,6961.810059,0.593785,00:33
2,6738.964844,6902.935547,0.602289,00:33
3,6286.810059,6540.906738,0.584166,00:33
4,5807.778809,6691.576172,0.592721,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,10308.719727,9629.15918,0.595648,00:33
1,9577.322266,9274.138672,0.590329,00:33
2,8984.635742,8711.731445,0.588041,00:33
3,8381.754883,8884.549805,0.587013,00:33
4,7776.286133,9820.355469,0.580717,00:33


epoch,train_loss,valid_loss,ctc_accuracy,time
0,12929.353516,12131.610352,0.59714,00:33


KeyboardInterrupt: 