In [36]:
import random
import torch
import numpy as np
from kinpfn.model import KINPFN

def set_seed(seed=123):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Data for RAW 264.7 and BMDM cells (IL1alpha, IL1Ba, TNFa)
raw_2647_il1a = [
    14, 0, 10, 2, 4, 4, 1, 4, 1, 2, 0, 25, 89, 89, 43, 19, 3, 1, 1, 4, 96, 109, 2, 5, 8, 3, 1, 0, 1, 299, 69, 
    81, 86, 15, 12, 0, 0, 1, 319, 0, 420, 13, 38, 234, 3, 9, 35, 0, 11, 4, 1, 2, 0, 0, 0, 42, 43, 1, 7, 3, 
    3, 41, 2, 0, 2, 1, 0, 0, 3, 2, 89, 0, 1, 2, 0, 4, 63, 83, 51, 123, 40, 31, 58, 2, 3, 12, 40, 7, 0, 2, 0, 
    111, 5, 59, 0, 0, 0, 0, 30, 0, 2, 0, 1, 8, 5, 63, 0, 2, 5, 1, 1, 3, 0, 29, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 
    0, 0, 0, 0, 187, 1, 1, 0, 0, 25, 0, 0, 0, 0, 0, 183, 0, 180, 0, 1, 12, 35, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 
    0, 14, 13, 0, 106, 0, 0, 23, 11, 56, 76, 23, 40, 6, 0, 0, 2, 3, 1, 0, 96, 52, 3, 18, 55, 19, 4, 15, 27, 0, 
    4, 3, 2, 27, 27, 1, 0, 0, 0, 0, 5, 0, 1, 5, 0, 12, 2, 0, 267, 8, 54, 2, 18, 1, 328, 4, 3, 1, 0, 0, 16, 12, 
    1, 33, 66, 2, 1, 7, 2, 2, 1, 1, 37, 2, 2, 12, 4, 8, 134, 5, 9, 0, 2, 2, 8, 16, 23, 7, 7, 3, 13, 134, 0, 
    0, 11, 0, 18, 3, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 347, 0, 1, 168, 5, 2, 208, 3, 2, 7, 6, 51, 14, 36, 
    5, 3, 11, 235, 14, 444, 134, 2, 0, 2, 11, 12, 63, 9, 485, 1, 1, 7, 158, 11, 2, 199, 1, 12, 89, 7, 0, 3, 
    382, 6, 1, 35, 5, 45, 18, 7, 18, 10, 1, 0, 618
]

bmdm_il1a = [
    291, 84, 29, 55, 117, 67, 34, 27, 111, 4, 7, 69, 16, 1, 24, 17, 2, 102, 43, 84, 15, 3, 7, 4, 153, 64, 13, 
    0, 0, 1, 34, 29, 6, 163, 18, 25, 81, 45, 43, 9, 12, 2, 9, 6, 228, 1, 35, 6, 5, 54, 4, 3, 1, 10, 0, 3, 111, 
    0, 2, 2, 6, 48, 3, 22, 50, 17, 164, 291, 2, 2, 3, 2, 2, 0, 5, 2, 40, 2, 1, 0, 146, 2, 2, 9, 10, 3, 3, 2, 
    14, 5, 77, 7, 62, 65, 3, 7, 37, 7, 20, 1, 7, 25, 10, 4, 45, 106, 99, 91, 124, 16, 27, 23, 15, 177, 304, 14, 
    75, 1, 7, 4, 6, 113, 106, 127, 43, 33, 33, 26, 110, 13, 6, 216, 7, 141, 22, 11, 14, 6
]

raw_2647_il1b = [
    696, 435, 500, 403, 13, 23, 5, 8, 326, 398, 4, 16, 14, 19, 654, 129, 108, 22, 5, 19, 136, 4, 7, 309, 329, 0, 3, 
    15, 21, 29, 150, 89, 60, 405, 48, 327, 238, 71, 46, 102, 706, 543, 457, 190, 1, 406, 151, 199, 54, 87, 25, 224, 
    14, 1, 10, 80, 0, 35, 40, 19, 910, 191, 121, 12, 15, 54, 70, 140, 115, 26, 36, 2, 337, 2, 75, 13, 5, 10, 1, 
    100, 46, 165, 4, 5, 73, 20, 35, 36, 90, 206, 208, 244, 96, 75, 7, 33, 604, 242, 7, 98, 16, 245, 538, 17, 0, 280, 
    7, 1, 27, 202, 3, 33, 439, 1, 0, 1, 1, 360, 36, 196, 219, 223, 0, 160, 524, 24, 42, 0, 0, 276, 534, 153, 217, 
    310, 272, 0, 408, 102, 918, 10, 44, 250, 21, 319, 470, 607, 415, 114, 593, 15, 25, 4, 211, 34, 133, 108, 149, 
    449, 198, 49, 60, 671, 557, 247, 70, 166, 18, 30, 21, 561, 16, 935, 352, 456, 453, 7, 505, 227, 423, 794, 130, 
    98, 173, 156, 28, 324, 176, 290, 30, 507, 677, 309, 13, 414, 60, 487, 43, 645, 354, 286, 670, 135, 501, 242, 31, 
    155, 489, 15, 41, 54, 442, 228, 85, 526, 1003, 10, 381, 404, 483, 21, 345, 274, 31, 96, 71, 422, 939, 700, 795, 
    305, 81, 5, 0, 0, 1, 207, 379, 1384, 10, 221, 303, 1065, 15, 5, 672, 397, 116, 171, 516, 333, 391, 390, 25, 249, 
    3, 296, 635, 79, 395, 27, 35, 198, 179, 373, 165, 11, 2, 8, 201, 68, 109, 12, 227, 743, 751, 711, 334, 27, 87, 
    479, 32, 608, 45, 38, 782, 33, 145, 19, 325, 69, 21, 15, 179, 301, 6, 7, 4, 402, 207, 153, 106, 23, 161, 2, 3, 
    6, 19, 280, 60, 54, 240, 153, 706, 689, 11, 22, 125, 343, 9, 4, 6, 3, 123, 138, 44, 114, 10, 179, 25, 62, 35, 
    12, 21, 17, 4, 1, 23, 8, 2, 233, 2, 172, 107, 3, 3, 0, 8, 162, 5, 24, 99, 336, 260, 60, 88, 37, 5, 608, 357, 7, 
    21, 408, 23, 0, 238, 393, 74, 2, 10, 121, 321, 56, 1, 92, 26, 365, 365, 160, 4, 163, 188, 2, 9, 2, 43, 282, 1, 
    479, 223, 380, 107, 111, 204, 6, 15, 67, 59, 37, 192, 252, 72, 377, 471, 289, 181, 83, 40, 29, 29, 169, 461, 
    33, 2, 494, 377, 131, 335, 26, 6, 3, 385, 414, 59, 96, 131, 177, 11, 2, 26, 154, 188, 12, 152, 344, 483, 94, 
    296, 54, 381, 13, 104, 10, 44, 70, 342, 195, 269, 605, 2, 11, 180, 26, 228, 567, 6, 20, 410, 241, 221, 160, 91, 
    16, 3, 147, 7, 1, 8, 235, 4, 15, 138, 220, 0, 0, 0, 10, 370, 64, 152, 14, 1, 20, 275, 187, 31, 63, 71, 117, 
    307, 90, 10, 258, 0, 47, 0, 21, 0, 10, 290, 2, 1, 15, 4, 16, 34, 495, 425, 58, 619, 498, 29, 599, 173, 682, 
    432, 592, 70, 49, 590, 216, 427, 42, 46, 13, 49, 358, 277, 290, 199, 437, 352, 339, 115, 188, 40, 70, 443, 506, 
    28, 38, 6, 21, 894, 251, 152, 75, 10, 317, 23, 4, 10, 355, 173, 251, 490, 207, 10, 620, 321, 461, 826, 172, 
    375, 4, 12, 212, 94, 185, 30, 732, 588, 417, 82, 2, 255, 5, 31, 6, 872, 704, 77, 851, 1076, 536, 337, 566, 569, 
    840, 930, 628, 447, 705, 21, 204, 712, 561, 128, 756, 806, 1031, 23, 638, 350, 412, 900, 193, 209, 74, 141, 
    429, 590, 953, 1272, 68, 1060, 767, 678, 756, 506, 29, 440, 503, 176, 11, 284, 811, 911, 220, 496, 38, 469, 249, 
    162, 540, 791, 16, 156, 362, 40, 71, 540, 26, 33, 337, 868, 17, 209, 131, 33, 450, 19, 53, 1, 1, 11, 115, 4, 10, 
    3, 5, 711, 435, 291, 457, 3, 0, 1, 4, 4, 2, 4, 231, 387, 2, 301, 9, 410, 587, 51, 64, 4, 243, 136, 367, 164, 19, 
    11, 421, 152, 2, 11, 4, 5, 2, 689, 175, 1, 20, 17, 58, 122, 535, 160, 137, 10, 8, 278, 166, 19, 20, 606, 500, 
    386, 276, 313, 19, 6, 4
]

bmdm_il1b = [
    195, 116, 20, 510, 916, 17, 9, 7, 107, 72, 120, 205, 22, 193, 171, 38, 1, 124, 20, 36, 18, 7, 4, 8, 0, 23, 5, 
    6, 24, 11, 33, 5, 115, 154, 13, 1, 17, 12, 1, 1, 49, 478, 473, 15, 4, 58, 2, 125, 11, 441, 31, 5, 274, 14, 31, 
    3, 5, 4, 67, 97, 35, 35, 2, 283, 24, 84, 838, 460, 2, 124, 156, 43, 222, 10, 20, 54, 58, 193, 15, 1001, 24, 239, 
    136, 18, 14, 99, 541, 592, 20, 23, 5, 0, 18, 60, 16, 1, 9, 86, 2, 8, 17, 175, 58, 108, 11, 8, 1, 93, 355, 556, 
    12, 14, 494, 309, 24, 15, 81, 2, 325, 40, 494, 225, 26, 44, 284, 0, 2, 0, 501, 4, 128, 24, 0, 35, 13, 372, 27, 
    15, 128, 94, 170, 2, 229, 294, 0, 0, 14, 38, 3, 0, 279, 50, 22, 32, 10, 29, 0, 1, 22, 117, 26, 9, 18, 1, 7, 10, 
    109, 85, 14, 159, 195, 32, 6, 1, 4, 66, 152, 142, 64, 156, 213, 93, 102, 0, 24, 4, 8, 7, 66, 13, 82, 2, 0, 10, 
    32, 0, 46, 1, 1, 19, 182, 156, 348, 319, 99, 449, 75, 68, 104, 99, 95, 292, 41, 151, 127, 26, 78, 214, 281, 322, 
    85, 57, 338, 623, 137, 28, 54, 208, 144, 387, 159, 73, 83, 166, 67, 68, 73, 57, 65, 432, 70, 62, 47, 199, 192, 
    162, 219, 181, 224, 216, 149, 144, 30, 80, 186, 62, 42, 44, 41, 17, 20, 61, 22, 400, 232, 273, 155, 198, 126, 
    105, 155, 323, 118, 91, 59, 76, 51, 87, 44, 43, 156, 32, 26, 75, 183, 60, 522, 259, 107, 72, 151, 516, 221, 37, 
    76, 77, 54, 176, 220, 116, 269, 59, 190, 55, 327, 116, 27, 43, 60, 47, 175, 154, 1, 8, 37, 1, 14, 18, 346, 3, 6, 
    0, 137, 3, 11, 36, 86, 305, 71, 278, 4, 44, 44, 343, 174, 193, 28, 41, 44, 82, 50, 4, 5, 43, 69, 47, 146, 63, 0, 
    0, 214, 27, 105, 3, 186, 16, 187, 177, 0, 3, 0, 6, 4, 76, 2, 42, 18, 439, 121, 3, 0, 6, 66, 4, 10, 13, 2, 5, 1, 
    123, 7, 36, 148, 150, 34, 85, 16, 5, 5, 249, 585, 88, 1, 4, 0, 4, 247, 14, 4, 3, 7, 8, 14, 67, 604, 272, 463, 
    130, 60, 32, 5, 1, 20, 4, 75, 12, 125, 62, 0, 744, 51, 1, 585, 133, 96, 36, 4, 0, 26, 10, 86, 60, 116, 40, 81, 
    49, 66, 142, 20, 1, 19, 167, 230, 105, 32, 24, 91, 64, 12, 3, 271, 24, 28, 107, 135, 53, 11, 16, 4, 88, 273, 18, 
    14, 5, 21, 11, 230, 134, 6, 27, 55, 73, 4, 3, 50, 127, 24, 3, 107, 71, 48, 130, 107, 103, 50, 191, 36, 37, 88, 
    308, 17, 87, 104, 48, 30, 2, 35, 44, 24, 13, 59, 21, 119, 380, 10, 58, 40, 9, 703, 205, 195, 6, 370, 280, 155, 
    129, 320, 36, 11, 18, 146, 164, 583, 209, 196, 12, 21, 66, 436, 24, 87, 9, 393, 23, 240, 17, 102, 20, 205, 288, 
    12, 42, 6, 12, 371, 164, 126, 40, 7, 7, 9, 121, 23, 49, 2, 7, 151, 58, 107, 131, 128, 139, 443, 2, 46, 54, 1, 0, 
    10, 64, 83, 55, 2, 2, 1, 357, 3, 14, 174, 25, 28, 35, 16, 174, 82, 57, 84, 5, 5, 63, 127, 77, 42, 14, 173, 225, 
    85, 317, 185, 36, 227, 77, 321, 60, 16, 16, 88, 32, 145, 10, 325, 347, 468, 307, 218, 97, 137, 6, 110, 5, 34, 85, 
    105, 74, 91, 69, 126, 228, 11, 39, 5, 5, 257, 27, 1, 3, 5, 31, 33, 1, 212, 7, 3, 78, 197, 11, 76, 8, 47, 1, 0, 
    16, 66, 183, 128, 4, 90, 4, 9, 126, 140, 84, 85, 6, 2, 0, 2, 0, 2, 64, 196, 11, 49, 4, 19, 8, 59, 78, 2, 246, 
    66, 279, 0, 4, 97, 38, 120, 27, 6, 278, 483, 447, 88, 191, 19, 541, 109, 99, 272, 270, 264, 11, 71, 7, 59, 19, 
    48, 88, 23, 355, 15, 24, 27, 24, 129, 19, 17, 85, 365, 125, 116, 263, 103, 49, 509, 427, 29, 118, 770, 180, 271, 
    188
]


raw_2647_tnfa = [
    66, 258, 120, 408, 506, 144, 206, 188, 359, 307, 182, 54, 376, 279, 125, 216, 6, 201, 395, 277, 209, 218, 261, 
    458, 583, 178, 121, 187, 465, 371, 210, 453, 160, 180, 243, 312, 392, 96, 380, 194, 217, 187, 99, 161, 202, 140, 
    162, 165, 134, 83, 155, 509, 225, 633, 232, 186, 99, 338, 351, 123, 214, 316, 352, 126, 170, 173, 236, 103, 624, 
    368, 97, 325, 199, 242, 150, 417, 367, 243, 140, 153, 102, 188, 111, 335, 147, 135, 201, 333, 347, 97, 59, 231, 
    83, 138, 313, 175, 239, 25, 475, 301, 331, 326, 263, 426, 109, 214, 297, 204, 441, 355, 65, 247, 276, 350, 325, 
    808, 165, 53, 279, 420, 75, 679, 372, 118, 282, 171, 461, 196, 330, 470, 122, 399, 543, 211, 100, 121, 300, 90, 
    162, 400, 423, 480, 155, 313, 385, 476, 272, 299, 274, 281, 136, 177, 423, 119, 159, 431, 398, 282, 684, 119, 
    129, 256, 49, 172, 140, 231, 319, 299, 131, 216, 252, 130, 55, 35, 405, 50, 36, 213, 278, 389, 473, 432, 132, 135, 
    101, 100, 457, 331, 311, 170, 36, 214, 130, 229, 372, 387, 181, 159, 238, 329, 236, 259, 493, 8, 440, 203, 554, 
    275, 373, 451, 591, 257, 249, 151, 80, 843, 431, 244, 242, 484, 831, 520, 837, 359, 804, 848, 805, 215, 435, 415, 
    660, 189, 261, 583, 457, 715, 458, 715, 397, 559, 268, 172, 348, 473, 551, 472, 556, 410, 296, 373, 900, 403, 
    264, 431, 426, 439, 320, 116, 331, 337, 621, 380, 182, 134, 324, 764, 737, 273, 355, 154, 426, 295, 821, 168, 
    229, 328, 144, 218, 516, 319, 392, 256, 546, 460, 281, 808, 445, 448, 212, 1006, 267, 206, 616, 311, 126, 321, 
    171, 174, 631, 551, 488, 623, 253, 218, 292, 259, 221, 465, 339, 426, 925, 241, 286, 596, 1334, 692, 205, 306, 
    476, 665, 466, 512, 392, 276, 367, 1018, 814, 89, 281, 670, 428, 332, 755, 568, 500, 512, 316, 448, 318, 521, 
    256, 608, 251, 364, 1001, 715, 612, 402, 612, 580, 672, 671, 417, 369, 158, 393
]


bmdm_tnfa = [
    150, 207, 199, 247, 185, 445, 309, 80, 98, 106, 135, 201, 18, 253, 335, 562, 551, 99, 325, 7, 503, 286, 183, 
    193, 174, 131, 142, 103, 340, 410, 326, 65, 222, 112, 306, 119, 55, 91, 177, 280, 342, 273, 226, 188, 72, 257, 
    72, 36, 19, 161, 11, 6, 2, 15, 94, 34, 100, 110, 48, 7, 5, 31, 479, 244, 238, 519, 851, 378, 435, 161, 541, 
    515, 409, 376, 484, 275, 228, 714, 117, 501, 289, 767, 262, 615, 204, 318, 791, 226, 166, 329, 543, 458, 486, 
    404, 281, 398, 468, 390, 170, 279, 462, 133, 362, 206, 267, 365, 358, 157, 247, 465, 348, 232, 176, 199, 371, 
    225, 238, 565, 213, 122, 53, 400, 149, 199, 367, 260, 690, 164, 138, 469, 132, 240, 375, 320, 212, 258, 75, 563, 
    170, 259, 497, 349, 238, 407, 65, 310, 501, 558, 218, 570, 258, 177, 62, 188, 412, 224, 407, 453, 386, 268, 675, 
    304, 191, 166, 323, 49, 182, 223, 274, 313, 216, 376, 169, 439, 168, 140, 179, 407, 113, 150, 211, 198, 248, 
    263, 171, 194, 137, 331, 369, 319, 159, 407, 108, 176, 268, 304, 119, 186, 136, 174, 166, 445, 299, 279, 267, 
    249, 210, 276, 264, 208, 401, 389, 526, 402, 213, 445, 234, 527, 175, 219, 416, 434, 294, 285, 124, 164, 319, 
    350, 362, 280, 287, 247, 312, 106, 278, 258, 306, 230, 356, 286, 71, 193, 28, 40, 208, 240, 128, 12, 133, 279, 
    318, 143, 59, 249, 241, 106, 278, 130, 111, 195, 434, 363, 168, 79, 269, 194, 183, 103, 157, 255, 56, 81, 222, 
    165, 248, 133, 20, 88, 117, 57, 152, 67, 36, 48, 166, 246, 81, 16, 76, 19, 26, 25, 64, 52, 88, 29, 80, 66, 102, 
    420, 69, 103, 58, 30, 14, 319, 20, 64, 96, 99, 306, 192, 161, 30, 292, 147, 63, 312, 208, 137, 24, 32
]


# Function to compute CDF
def compute_cdf(data):
    sorted_data = np.sort(data)
    cdf = np.arange(1, len(data) + 1) / len(data)
    return sorted_data, cdf

# Compute CDFs for all RNA data
raw_il1a_sorted, raw_il1a_cdf = compute_cdf(raw_2647_il1a)
bmdm_il1a_sorted, bmdm_il1a_cdf = compute_cdf(bmdm_il1a)

raw_il1b_sorted, raw_il1b_cdf = compute_cdf(raw_2647_il1b)
bmdm_il1b_sorted, bmdm_il1b_cdf = compute_cdf(bmdm_il1b)

raw_tnfa_sorted, raw_tnfa_cdf = compute_cdf(raw_2647_tnfa)
bmdm_tnfa_sorted, bmdm_tnfa_cdf = compute_cdf(bmdm_tnfa)

print("Gene Experiment Ground Truth Data")
# Create two plots for Raw 264.7 and BMDM cells
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))

# Raw 264.7 cells
axes[0].plot(raw_il1a_sorted + 1, raw_il1a_cdf, label='IL1α', color='#6EC4E8', marker='o')
axes[0].plot(raw_il1b_sorted + 1, raw_il1b_cdf, label='IL1β', color='#1D3E99', marker='o')
axes[0].plot(raw_tnfa_sorted + 1, raw_tnfa_cdf, label='TNFα', color='#D8342C', marker='o')

# BMDM cells
axes[1].plot(bmdm_il1a_sorted + 1, bmdm_il1a_cdf, label='IL1α', color='#6EC4E8', marker='o')
axes[1].plot(bmdm_il1b_sorted + 1, bmdm_il1b_cdf, label='IL1β', color='#1D3E99', marker='o')
axes[1].plot(bmdm_tnfa_sorted + 1, bmdm_tnfa_cdf, label='TNFα', color='#D8342C', marker='o')

axes[0].set_xlabel(r'$\log_{10}$(mRNA + 1)', fontsize=18)
axes[0].set_ylabel('Cumulative Probability', fontsize=18)
axes[0].set_title('RAW 264.7', fontsize=16)
axes[0].set_xscale('log')
axes[0].tick_params(axis='both', labelsize=12)
axes[0].legend()

axes[1].set_xlabel(r'$\log_{10}$(mRNA + 1)', fontsize=18)
axes[1].set_ylabel('Cumulative Probability', fontsize=18)
axes[1].set_title('BMDM', fontsize=16)
axes[1].set_xscale('log')
axes[1].tick_params(axis='both', labelsize=12)
axes[1].legend()


In [None]:
# Load the trained model
model_path = "../../../../models/final_kinpfn_model_1400_1000_1000_86_50_2.5588748050825984e-05_256_4_512_8_0.0_0.0.pt"

kinpfn = KINPFN(
    model_path=model_path,
)
trained_model = kinpfn.model

if trained_model is not None:
    print("Load trained model!")
else:
    print("No trained model found!")
    exit()

In [39]:
raw_2647_il1a = np.array([
    14, 0, 10, 2, 4, 4, 1, 4, 1, 2, 0, 25, 89, 89, 43, 19, 3, 1, 1, 4, 96, 109, 2, 5, 8, 3, 1, 0, 1, 299, 69, 
    81, 86, 15, 12, 0, 0, 1, 319, 0, 420, 13, 38, 234, 3, 9, 35, 0, 11, 4, 1, 2, 0, 0, 0, 42, 43, 1, 7, 3, 
    3, 41, 2, 0, 2, 1, 0, 0, 3, 2, 89, 0, 1, 2, 0, 4, 63, 83, 51, 123, 40, 31, 58, 2, 3, 12, 40, 7, 0, 2, 0, 
    111, 5, 59, 0, 0, 0, 0, 30, 0, 2, 0, 1, 8, 5, 63, 0, 2, 5, 1, 1, 3, 0, 29, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 
    0, 0, 0, 0, 187, 1, 1, 0, 0, 25, 0, 0, 0, 0, 0, 183, 0, 180, 0, 1, 12, 35, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 
    0, 14, 13, 0, 106, 0, 0, 23, 11, 56, 76, 23, 40, 6, 0, 0, 2, 3, 1, 0, 96, 52, 3, 18, 55, 19, 4, 15, 27, 0, 
    4, 3, 2, 27, 27, 1, 0, 0, 0, 0, 5, 0, 1, 5, 0, 12, 2, 0, 267, 8, 54, 2, 18, 1, 328, 4, 3, 1, 0, 0, 16, 12, 
    1, 33, 66, 2, 1, 7, 2, 2, 1, 1, 37, 2, 2, 12, 4, 8, 134, 5, 9, 0, 2, 2, 8, 16, 23, 7, 7, 3, 13, 134, 0, 
    0, 11, 0, 18, 3, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 347, 0, 1, 168, 5, 2, 208, 3, 2, 7, 6, 51, 14, 36, 
    5, 3, 11, 235, 14, 444, 134, 2, 0, 2, 11, 12, 63, 9, 485, 1, 1, 7, 158, 11, 2, 199, 1, 12, 89, 7, 0, 3, 
    382, 6, 1, 35, 5, 45, 18, 7, 18, 10, 1, 0, 618
])

bmdm_il1a = np.array([
    291, 84, 29, 55, 117, 67, 34, 27, 111, 4, 7, 69, 16, 1, 24, 17, 2, 102, 43, 84, 15, 3, 7, 4, 153, 64, 13, 
    0, 0, 1, 34, 29, 6, 163, 18, 25, 81, 45, 43, 9, 12, 2, 9, 6, 228, 1, 35, 6, 5, 54, 4, 3, 1, 10, 0, 3, 111, 
    0, 2, 2, 6, 48, 3, 22, 50, 17, 164, 291, 2, 2, 3, 2, 2, 0, 5, 2, 40, 2, 1, 0, 146, 2, 2, 9, 10, 3, 3, 2, 
    14, 5, 77, 7, 62, 65, 3, 7, 37, 7, 20, 1, 7, 25, 10, 4, 45, 106, 99, 91, 124, 16, 27, 23, 15, 177, 304, 14, 
    75, 1, 7, 4, 6, 113, 106, 127, 43, 33, 33, 26, 110, 13, 6, 216, 7, 141, 22, 11, 14, 6
])

raw_2647_il1b = np.array([
    696, 435, 500, 403, 13, 23, 5, 8, 326, 398, 4, 16, 14, 19, 654, 129, 108, 22, 5, 19, 136, 4, 7, 309, 329, 0, 3, 
    15, 21, 29, 150, 89, 60, 405, 48, 327, 238, 71, 46, 102, 706, 543, 457, 190, 1, 406, 151, 199, 54, 87, 25, 224, 
    14, 1, 10, 80, 0, 35, 40, 19, 910, 191, 121, 12, 15, 54, 70, 140, 115, 26, 36, 2, 337, 2, 75, 13, 5, 10, 1, 
    100, 46, 165, 4, 5, 73, 20, 35, 36, 90, 206, 208, 244, 96, 75, 7, 33, 604, 242, 7, 98, 16, 245, 538, 17, 0, 280, 
    7, 1, 27, 202, 3, 33, 439, 1, 0, 1, 1, 360, 36, 196, 219, 223, 0, 160, 524, 24, 42, 0, 0, 276, 534, 153, 217, 
    310, 272, 0, 408, 102, 918, 10, 44, 250, 21, 319, 470, 607, 415, 114, 593, 15, 25, 4, 211, 34, 133, 108, 149, 
    449, 198, 49, 60, 671, 557, 247, 70, 166, 18, 30, 21, 561, 16, 935, 352, 456, 453, 7, 505, 227, 423, 794, 130, 
    98, 173, 156, 28, 324, 176, 290, 30, 507, 677, 309, 13, 414, 60, 487, 43, 645, 354, 286, 670, 135, 501, 242, 31, 
    155, 489, 15, 41, 54, 442, 228, 85, 526, 1003, 10, 381, 404, 483, 21, 345, 274, 31, 96, 71, 422, 939, 700, 795, 
    305, 81, 5, 0, 0, 1, 207, 379, 1384, 10, 221, 303, 1065, 15, 5, 672, 397, 116, 171, 516, 333, 391, 390, 25, 249, 
    3, 296, 635, 79, 395, 27, 35, 198, 179, 373, 165, 11, 2, 8, 201, 68, 109, 12, 227, 743, 751, 711, 334, 27, 87, 
    479, 32, 608, 45, 38, 782, 33, 145, 19, 325, 69, 21, 15, 179, 301, 6, 7, 4, 402, 207, 153, 106, 23, 161, 2, 3, 
    6, 19, 280, 60, 54, 240, 153, 706, 689, 11, 22, 125, 343, 9, 4, 6, 3, 123, 138, 44, 114, 10, 179, 25, 62, 35, 
    12, 21, 17, 4, 1, 23, 8, 2, 233, 2, 172, 107, 3, 3, 0, 8, 162, 5, 24, 99, 336, 260, 60, 88, 37, 5, 608, 357, 7, 
    21, 408, 23, 0, 238, 393, 74, 2, 10, 121, 321, 56, 1, 92, 26, 365, 365, 160, 4, 163, 188, 2, 9, 2, 43, 282, 1, 
    479, 223, 380, 107, 111, 204, 6, 15, 67, 59, 37, 192, 252, 72, 377, 471, 289, 181, 83, 40, 29, 29, 169, 461, 
    33, 2, 494, 377, 131, 335, 26, 6, 3, 385, 414, 59, 96, 131, 177, 11, 2, 26, 154, 188, 12, 152, 344, 483, 94, 
    296, 54, 381, 13, 104, 10, 44, 70, 342, 195, 269, 605, 2, 11, 180, 26, 228, 567, 6, 20, 410, 241, 221, 160, 91, 
    16, 3, 147, 7, 1, 8, 235, 4, 15, 138, 220, 0, 0, 0, 10, 370, 64, 152, 14, 1, 20, 275, 187, 31, 63, 71, 117, 
    307, 90, 10, 258, 0, 47, 0, 21, 0, 10, 290, 2, 1, 15, 4, 16, 34, 495, 425, 58, 619, 498, 29, 599, 173, 682, 
    432, 592, 70, 49, 590, 216, 427, 42, 46, 13, 49, 358, 277, 290, 199, 437, 352, 339, 115, 188, 40, 70, 443, 506, 
    28, 38, 6, 21, 894, 251, 152, 75, 10, 317, 23, 4, 10, 355, 173, 251, 490, 207, 10, 620, 321, 461, 826, 172, 
    375, 4, 12, 212, 94, 185, 30, 732, 588, 417, 82, 2, 255, 5, 31, 6, 872, 704, 77, 851, 1076, 536, 337, 566, 569, 
    840, 930, 628, 447, 705, 21, 204, 712, 561, 128, 756, 806, 1031, 23, 638, 350, 412, 900, 193, 209, 74, 141, 
    429, 590, 953, 1272, 68, 1060, 767, 678, 756, 506, 29, 440, 503, 176, 11, 284, 811, 911, 220, 496, 38, 469, 249, 
    162, 540, 791, 16, 156, 362, 40, 71, 540, 26, 33, 337, 868, 17, 209, 131, 33, 450, 19, 53, 1, 1, 11, 115, 4, 10, 
    3, 5, 711, 435, 291, 457, 3, 0, 1, 4, 4, 2, 4, 231, 387, 2, 301, 9, 410, 587, 51, 64, 4, 243, 136, 367, 164, 19, 
    11, 421, 152, 2, 11, 4, 5, 2, 689, 175, 1, 20, 17, 58, 122, 535, 160, 137, 10, 8, 278, 166, 19, 20, 606, 500, 
    386, 276, 313, 19, 6, 4
])


bmdm_il1b = np.array([
    195, 116, 20, 510, 916, 17, 9, 7, 107, 72, 120, 205, 22, 193, 171, 38, 1, 124, 20, 36, 18, 7, 4, 8, 0, 23, 5, 
    6, 24, 11, 33, 5, 115, 154, 13, 1, 17, 12, 1, 1, 49, 478, 473, 15, 4, 58, 2, 125, 11, 441, 31, 5, 274, 14, 31, 
    3, 5, 4, 67, 97, 35, 35, 2, 283, 24, 84, 838, 460, 2, 124, 156, 43, 222, 10, 20, 54, 58, 193, 15, 1001, 24, 239, 
    136, 18, 14, 99, 541, 592, 20, 23, 5, 0, 18, 60, 16, 1, 9, 86, 2, 8, 17, 175, 58, 108, 11, 8, 1, 93, 355, 556, 
    12, 14, 494, 309, 24, 15, 81, 2, 325, 40, 494, 225, 26, 44, 284, 0, 2, 0, 501, 4, 128, 24, 0, 35, 13, 372, 27, 
    15, 128, 94, 170, 2, 229, 294, 0, 0, 14, 38, 3, 0, 279, 50, 22, 32, 10, 29, 0, 1, 22, 117, 26, 9, 18, 1, 7, 10, 
    109, 85, 14, 159, 195, 32, 6, 1, 4, 66, 152, 142, 64, 156, 213, 93, 102, 0, 24, 4, 8, 7, 66, 13, 82, 2, 0, 10, 
    32, 0, 46, 1, 1, 19, 182, 156, 348, 319, 99, 449, 75, 68, 104, 99, 95, 292, 41, 151, 127, 26, 78, 214, 281, 322, 
    85, 57, 338, 623, 137, 28, 54, 208, 144, 387, 159, 73, 83, 166, 67, 68, 73, 57, 65, 432, 70, 62, 47, 199, 192, 
    162, 219, 181, 224, 216, 149, 144, 30, 80, 186, 62, 42, 44, 41, 17, 20, 61, 22, 400, 232, 273, 155, 198, 126, 
    105, 155, 323, 118, 91, 59, 76, 51, 87, 44, 43, 156, 32, 26, 75, 183, 60, 522, 259, 107, 72, 151, 516, 221, 37, 
    76, 77, 54, 176, 220, 116, 269, 59, 190, 55, 327, 116, 27, 43, 60, 47, 175, 154, 1, 8, 37, 1, 14, 18, 346, 3, 6, 
    0, 137, 3, 11, 36, 86, 305, 71, 278, 4, 44, 44, 343, 174, 193, 28, 41, 44, 82, 50, 4, 5, 43, 69, 47, 146, 63, 0, 
    0, 214, 27, 105, 3, 186, 16, 187, 177, 0, 3, 0, 6, 4, 76, 2, 42, 18, 439, 121, 3, 0, 6, 66, 4, 10, 13, 2, 5, 1, 
    123, 7, 36, 148, 150, 34, 85, 16, 5, 5, 249, 585, 88, 1, 4, 0, 4, 247, 14, 4, 3, 7, 8, 14, 67, 604, 272, 463, 
    130, 60, 32, 5, 1, 20, 4, 75, 12, 125, 62, 0, 744, 51, 1, 585, 133, 96, 36, 4, 0, 26, 10, 86, 60, 116, 40, 81, 
    49, 66, 142, 20, 1, 19, 167, 230, 105, 32, 24, 91, 64, 12, 3, 271, 24, 28, 107, 135, 53, 11, 16, 4, 88, 273, 18, 
    14, 5, 21, 11, 230, 134, 6, 27, 55, 73, 4, 3, 50, 127, 24, 3, 107, 71, 48, 130, 107, 103, 50, 191, 36, 37, 88, 
    308, 17, 87, 104, 48, 30, 2, 35, 44, 24, 13, 59, 21, 119, 380, 10, 58, 40, 9, 703, 205, 195, 6, 370, 280, 155, 
    129, 320, 36, 11, 18, 146, 164, 583, 209, 196, 12, 21, 66, 436, 24, 87, 9, 393, 23, 240, 17, 102, 20, 205, 288, 
    12, 42, 6, 12, 371, 164, 126, 40, 7, 7, 9, 121, 23, 49, 2, 7, 151, 58, 107, 131, 128, 139, 443, 2, 46, 54, 1, 0, 
    10, 64, 83, 55, 2, 2, 1, 357, 3, 14, 174, 25, 28, 35, 16, 174, 82, 57, 84, 5, 5, 63, 127, 77, 42, 14, 173, 225, 
    85, 317, 185, 36, 227, 77, 321, 60, 16, 16, 88, 32, 145, 10, 325, 347, 468, 307, 218, 97, 137, 6, 110, 5, 34, 85, 
    105, 74, 91, 69, 126, 228, 11, 39, 5, 5, 257, 27, 1, 3, 5, 31, 33, 1, 212, 7, 3, 78, 197, 11, 76, 8, 47, 1, 0, 
    16, 66, 183, 128, 4, 90, 4, 9, 126, 140, 84, 85, 6, 2, 0, 2, 0, 2, 64, 196, 11, 49, 4, 19, 8, 59, 78, 2, 246, 
    66, 279, 0, 4, 97, 38, 120, 27, 6, 278, 483, 447, 88, 191, 19, 541, 109, 99, 272, 270, 264, 11, 71, 7, 59, 19, 
    48, 88, 23, 355, 15, 24, 27, 24, 129, 19, 17, 85, 365, 125, 116, 263, 103, 49, 509, 427, 29, 118, 770, 180, 271, 
    188
])


raw_2647_tnfa = np.array([
    66, 258, 120, 408, 506, 144, 206, 188, 359, 307, 182, 54, 376, 279, 125, 216, 6, 201, 395, 277, 209, 218, 261, 
    458, 583, 178, 121, 187, 465, 371, 210, 453, 160, 180, 243, 312, 392, 96, 380, 194, 217, 187, 99, 161, 202, 140, 
    162, 165, 134, 83, 155, 509, 225, 633, 232, 186, 99, 338, 351, 123, 214, 316, 352, 126, 170, 173, 236, 103, 624, 
    368, 97, 325, 199, 242, 150, 417, 367, 243, 140, 153, 102, 188, 111, 335, 147, 135, 201, 333, 347, 97, 59, 231, 
    83, 138, 313, 175, 239, 25, 475, 301, 331, 326, 263, 426, 109, 214, 297, 204, 441, 355, 65, 247, 276, 350, 325, 
    808, 165, 53, 279, 420, 75, 679, 372, 118, 282, 171, 461, 196, 330, 470, 122, 399, 543, 211, 100, 121, 300, 90, 
    162, 400, 423, 480, 155, 313, 385, 476, 272, 299, 274, 281, 136, 177, 423, 119, 159, 431, 398, 282, 684, 119, 
    129, 256, 49, 172, 140, 231, 319, 299, 131, 216, 252, 130, 55, 35, 405, 50, 36, 213, 278, 389, 473, 432, 132, 135, 
    101, 100, 457, 331, 311, 170, 36, 214, 130, 229, 372, 387, 181, 159, 238, 329, 236, 259, 493, 8, 440, 203, 554, 
    275, 373, 451, 591, 257, 249, 151, 80, 843, 431, 244, 242, 484, 831, 520, 837, 359, 804, 848, 805, 215, 435, 415, 
    660, 189, 261, 583, 457, 715, 458, 715, 397, 559, 268, 172, 348, 473, 551, 472, 556, 410, 296, 373, 900, 403, 
    264, 431, 426, 439, 320, 116, 331, 337, 621, 380, 182, 134, 324, 764, 737, 273, 355, 154, 426, 295, 821, 168, 
    229, 328, 144, 218, 516, 319, 392, 256, 546, 460, 281, 808, 445, 448, 212, 1006, 267, 206, 616, 311, 126, 321, 
    171, 174, 631, 551, 488, 623, 253, 218, 292, 259, 221, 465, 339, 426, 925, 241, 286, 596, 1334, 692, 205, 306, 
    476, 665, 466, 512, 392, 276, 367, 1018, 814, 89, 281, 670, 428, 332, 755, 568, 500, 512, 316, 448, 318, 521, 
    256, 608, 251, 364, 1001, 715, 612, 402, 612, 580, 672, 671, 417, 369, 158, 393
])


bmdm_tnfa = np.array([
    150, 207, 199, 247, 185, 445, 309, 80, 98, 106, 135, 201, 18, 253, 335, 562, 551, 99, 325, 7, 503, 286, 183, 
    193, 174, 131, 142, 103, 340, 410, 326, 65, 222, 112, 306, 119, 55, 91, 177, 280, 342, 273, 226, 188, 72, 257, 
    72, 36, 19, 161, 11, 6, 2, 15, 94, 34, 100, 110, 48, 7, 5, 31, 479, 244, 238, 519, 851, 378, 435, 161, 541, 
    515, 409, 376, 484, 275, 228, 714, 117, 501, 289, 767, 262, 615, 204, 318, 791, 226, 166, 329, 543, 458, 486, 
    404, 281, 398, 468, 390, 170, 279, 462, 133, 362, 206, 267, 365, 358, 157, 247, 465, 348, 232, 176, 199, 371, 
    225, 238, 565, 213, 122, 53, 400, 149, 199, 367, 260, 690, 164, 138, 469, 132, 240, 375, 320, 212, 258, 75, 563, 
    170, 259, 497, 349, 238, 407, 65, 310, 501, 558, 218, 570, 258, 177, 62, 188, 412, 224, 407, 453, 386, 268, 675, 
    304, 191, 166, 323, 49, 182, 223, 274, 313, 216, 376, 169, 439, 168, 140, 179, 407, 113, 150, 211, 198, 248, 
    263, 171, 194, 137, 331, 369, 319, 159, 407, 108, 176, 268, 304, 119, 186, 136, 174, 166, 445, 299, 279, 267, 
    249, 210, 276, 264, 208, 401, 389, 526, 402, 213, 445, 234, 527, 175, 219, 416, 434, 294, 285, 124, 164, 319, 
    350, 362, 280, 287, 247, 312, 106, 278, 258, 306, 230, 356, 286, 71, 193, 28, 40, 208, 240, 128, 12, 133, 279, 
    318, 143, 59, 249, 241, 106, 278, 130, 111, 195, 434, 363, 168, 79, 269, 194, 183, 103, 157, 255, 56, 81, 222, 
    165, 248, 133, 20, 88, 117, 57, 152, 67, 36, 48, 166, 246, 81, 16, 76, 19, 26, 25, 64, 52, 88, 29, 80, 66, 102, 
    420, 69, 103, 58, 30, 14, 319, 20, 64, 96, 99, 306, 192, 161, 30, 292, 147, 63, 312, 208, 137, 24, 32
])

num_features = 1 

# Sort the data
raw_sorted_il1a = np.sort(raw_2647_il1a)
bmdm_sorted_il1a = np.sort(bmdm_il1a)

raw_sorted_il1b = np.sort(raw_2647_il1b)
bmdm_sorted_il1b = np.sort(bmdm_il1b)

raw_sorted_tnfa = np.sort(raw_2647_tnfa)
bmdm_sorted_tnfa = np.sort(bmdm_tnfa)

def filter_data(data):
    valid_indices = (data <= 10**15) & (data >= 10**-6)
    return data[valid_indices]

raw_sorted_filtered_il1a = filter_data(raw_sorted_il1a)
bmdm_sorted_filtered_il1a = filter_data(bmdm_sorted_il1a)

raw_sorted_filtered_il1b = filter_data(raw_sorted_il1b)
bmdm_sorted_filtered_il1b = filter_data(bmdm_sorted_il1b)

raw_sorted_filtered_tnfa = filter_data(raw_sorted_tnfa)
bmdm_sorted_filtered_tnfa = filter_data(bmdm_sorted_tnfa)

def adjust_sequence_length(data, seq_len):
    #print(len(data))
    if len(data) < seq_len:
        repeat_factor = seq_len // len(data) + 1
        data = np.tile(data, repeat_factor)[:seq_len]
    else:
        data = data[:seq_len]
    return data

seq_len = 235
raw_sorted_filtered_il1a = adjust_sequence_length(raw_sorted_filtered_il1a, seq_len)
seq_len = 132
bmdm_sorted_filtered_il1a = adjust_sequence_length(bmdm_sorted_filtered_il1a, seq_len)

def create_tensors(seq_len, num_features):
    x = torch.zeros(seq_len, 2, num_features)
    y = torch.zeros(seq_len, 2)
    return x, y

seq_len = 699
raw_sorted_filtered_il1b = adjust_sequence_length(raw_sorted_filtered_il1b, seq_len)
seq_len = 706
bmdm_sorted_filtered_il1b = adjust_sequence_length(bmdm_sorted_filtered_il1b, seq_len)


seq_len = 356
raw_sorted_filtered_tnfa = adjust_sequence_length(raw_sorted_filtered_tnfa, seq_len)
seq_len = 322
bmdm_sorted_filtered_tnfa = adjust_sequence_length(bmdm_sorted_filtered_tnfa, seq_len)


In [40]:
def plot_kinpfn_on_selected_testing_set_seq(trained_model, raw_sorted_filtered_il1a, bmdm_sorted_filtered_il1a, raw_sorted_filtered_il1b,bmdm_sorted_filtered_il1b,raw_sorted_filtered_tnfa,bmdm_sorted_filtered_tnfa, seed=None):
    if seed is None:
        seed = random.randint(0, 10000)
    print(f"Seed: {seed}")
    set_seed(seed)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    training_point = 25

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
    fig.set_dpi(300)

    handles = []
    labels = []

    for rna in ["IL1a", "IL1b", "TNFa"]:
        
        if rna == "IL1a":
            seq_len_raw_bmdm = [235, 132]
            target_color = '#6EC4E8'
            pred_color = '#1F77B4'
            label = 'IL1α'
            raw_sorted_filtered = raw_sorted_filtered_il1a
            bmdm_sorted_filtered = bmdm_sorted_filtered_il1a
        elif rna == "IL1b":
            seq_len_raw_bmdm = [699, 706]
            target_color = '#1D3E99'
            pred_color = '#9467bd'
            label = 'IL1β'
            raw_sorted_filtered = raw_sorted_filtered_il1b
            bmdm_sorted_filtered = bmdm_sorted_filtered_il1b
        elif rna == "TNFa":
            seq_len_raw_bmdm = [356, 322]
            target_color = '#D8342C'
            pred_color = '#D62728'
            label = 'TNFα'
            raw_sorted_filtered = raw_sorted_filtered_tnfa
            bmdm_sorted_filtered = bmdm_sorted_filtered_tnfa
        else:
            raise ValueError("Invalid RNA marker. Choose from 'IL1a', 'IL1b', or 'TNFa'.")

        seq_len = seq_len_raw_bmdm[0]
        raw_2647_data_x = torch.zeros(seq_len, 1, dtype=torch.float32).to(device)
        seq_len = seq_len_raw_bmdm[1]
        bmdm_data_x = torch.zeros(seq_len, 1, dtype=torch.float32).to(device)
        
        raw_2647_data_y = torch.tensor(raw_sorted_filtered, dtype=torch.float32).to(device)
        raw_2647_data_y_log10 = torch.tensor(np.log10(raw_sorted_filtered), dtype=torch.float32).to(device)
        
        bmdm_data_y = torch.tensor(bmdm_sorted_filtered, dtype=torch.float32).to(device)
        bmdm_data_y_log10 = torch.tensor(np.log10(bmdm_sorted_filtered), dtype=torch.float32).to(device)

        datasets = [(raw_2647_data_x, raw_2647_data_y_log10, raw_2647_data_y), 
                    (bmdm_data_x, bmdm_data_y_log10, bmdm_data_y)]
        
        dataset_names = ["RAW 264.7", "BMDM"]

        for evaluations, (x, y_log10, y) in enumerate(datasets):
            col = evaluations
            ax = axes[col]
            
            ax.set_title(f"{dataset_names[evaluations]}", fontsize=16)
            seq_len = seq_len_raw_bmdm[evaluations]
            train_indices = torch.randperm(seq_len)[:training_point]
            
            train_x = x[train_indices]
            train_y_log10 = y_log10[train_indices]
            test_x = x
            test_y_log10 = y_log10

            train_x = train_x.to(device)
            train_y_log10 = train_y_log10.to(device)
            test_x = test_x.to(device)
            test_y_log10 = test_y_log10.to(device)

            with torch.no_grad():
                # Pass training and test data through the model
                logits = trained_model(train_x[:, None], train_y_log10[:, None], test_x[:, None])

            ground_truth_sorted, _ = torch.sort(test_y_log10)
            ground_truth_cdf = torch.arange(1, len(ground_truth_sorted) + 1) / len(ground_truth_sorted)

            linspace_extended = torch.linspace(train_y_log10.min() - 1, train_y_log10.max() + 1, 1000)
            pred_cdf_linspace_extended = trained_model.criterion.cdf(logits, linspace_extended)[0][0]

            # Plot results
            scatter1 = ax.scatter(10**ground_truth_sorted, ground_truth_cdf, color=target_color, marker="x", label=f"Target {label}", alpha=0.5)
            line2, = ax.plot(10**linspace_extended, pred_cdf_linspace_extended, color=target_color, marker=".", label=f"KinPFN {label}")
            scatter3 = ax.scatter(10**train_y_log10, torch.zeros_like(train_y_log10), color=target_color, marker="o", label=f"Context {label}")

            if evaluations == 0:
                handles.append(scatter1)
                handles.append(line2)
                handles.append(scatter3)
                labels.extend([f"Target {label}", f"KinPFN {label}", f"Context {label}"])

        axes[0].set_ylabel("Cumulative Probability", fontsize=18)
        for ax in axes:
            ax.set_xscale("log")
            ax.set_xlabel(r'$\log_{10}$(mRNA + 1)', fontsize=18)

            ax.tick_params(axis='both', which='major', labelsize=12)

    fig.legend(handles, labels, loc='lower center', ncol=3, fontsize=14, bbox_to_anchor=(0.5, -0.1))

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.2)
    plt.show()


In [41]:
def plot_kinpfn_on_selected_testing_set_seq_appendix(trained_model, raw_sorted_filtered_il1a, bmdm_sorted_filtered_il1a, raw_sorted_filtered_il1b,bmdm_sorted_filtered_il1b,raw_sorted_filtered_tnfa,bmdm_sorted_filtered_tnfa, seed=None):
    if seed is None:
        seed = random.randint(0, 10000)
    print(f"Seed: {seed}")
    set_seed(seed)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    training_points = [10, 25, 50, 75]

    fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(30, 12))
    fig.set_dpi(300)


    for training_point in training_points:

        handles = []
        labels = []

        for rna in ["IL1a", "IL1b", "TNFa"]:
            
            if rna == "IL1a":
                seq_len_raw_bmdm = [235, 132]
                target_color = '#6EC4E8'
                pred_color = '#1F77B4'
                label = 'IL1α'
                raw_sorted_filtered = raw_sorted_filtered_il1a
                bmdm_sorted_filtered = bmdm_sorted_filtered_il1a
            elif rna == "IL1b":
                seq_len_raw_bmdm = [699, 706]
                target_color = '#1D3E99'
                pred_color = '#9467bd'
                label = 'IL1β'
                raw_sorted_filtered = raw_sorted_filtered_il1b
                bmdm_sorted_filtered = bmdm_sorted_filtered_il1b
            elif rna == "TNFa":
                seq_len_raw_bmdm = [356, 322]
                target_color = '#D8342C'
                pred_color = '#D62728'
                label = 'TNFα'
                raw_sorted_filtered = raw_sorted_filtered_tnfa
                bmdm_sorted_filtered = bmdm_sorted_filtered_tnfa
            else:
                raise ValueError("Invalid RNA marker. Choose from 'IL1a', 'IL1b', or 'TNFa'.")

            seq_len = seq_len_raw_bmdm[0]
            raw_2647_data_x = torch.zeros(seq_len, 1, dtype=torch.float32).to(device)
            seq_len = seq_len_raw_bmdm[1]
            bmdm_data_x = torch.zeros(seq_len, 1, dtype=torch.float32).to(device)
            
            raw_2647_data_y = torch.tensor(raw_sorted_filtered, dtype=torch.float32).to(device)
            raw_2647_data_y_log10 = torch.tensor(np.log10(raw_sorted_filtered), dtype=torch.float32).to(device)
            
            bmdm_data_y = torch.tensor(bmdm_sorted_filtered, dtype=torch.float32).to(device)
            bmdm_data_y_log10 = torch.tensor(np.log10(bmdm_sorted_filtered), dtype=torch.float32).to(device)

            datasets = [(raw_2647_data_x, raw_2647_data_y_log10, raw_2647_data_y), 
                        (bmdm_data_x, bmdm_data_y_log10, bmdm_data_y)]
            
            dataset_names = ["RAW 264.7", "BMDM"]

            for evaluations, (x, y_log10, y) in enumerate(datasets):
                col = training_points.index(training_point)
                row = evaluations
                ax = axes[row, col]
                
                ax.set_title(f"{dataset_names[evaluations]}", fontsize=16)
                seq_len = seq_len_raw_bmdm[evaluations]

                train_indices = torch.randperm(seq_len)[:training_point]
                
                # Split data into training and testing
                train_x = x[train_indices]
                train_y_log10 = y_log10[train_indices]
                test_x = x
                test_y_log10 = y_log10

                train_x = train_x.to(device)
                train_y_log10 = train_y_log10.to(device)
                test_x = test_x.to(device)
                test_y_log10 = test_y_log10.to(device)

                with torch.no_grad():
                    # Pass training and test data through the model
                    logits = trained_model(train_x[:, None], train_y_log10[:, None], test_x[:, None])

                ground_truth_sorted, _ = torch.sort(test_y_log10)
                ground_truth_cdf = torch.arange(1, len(ground_truth_sorted) + 1) / len(ground_truth_sorted)

                linspace_extended = torch.linspace(train_y_log10.min() - 1, train_y_log10.max() + 1, 1000)
                pred_cdf_linspace_extended = trained_model.criterion.cdf(logits, linspace_extended)[0][0]

                # Plot results
                scatter1 = ax.scatter(10**ground_truth_sorted, ground_truth_cdf, color=target_color, marker="x", label=f"Target {label}", alpha=0.5)
                line2, = ax.plot(10**linspace_extended, pred_cdf_linspace_extended, color=target_color, marker=".", label=f"KinPFN {label}")
                scatter3 = ax.scatter(10**train_y_log10, torch.zeros_like(train_y_log10), color=target_color, marker="o", label=f"Context {label}")

                if evaluations == 0:
                    handles.append(scatter1)
                    handles.append(line2)
                    handles.append(scatter3)
                    labels.extend([f"Target {label}", f"KinPFN {label}", f"Context {label}"])

                ax.set_xscale("log")
                ax.set_xlabel(r'$\log_{10}$(mRNA + 1)', fontsize=18)

                ax.tick_params(axis='both', which='major', labelsize=12)

    axes[0,0].set_ylabel("Cumulative Probability", fontsize=18)
    axes[1,0].set_ylabel("Cumulative Probability", fontsize=18)

    fig.legend(handles, labels, loc='lower center', ncol=3, fontsize=25, bbox_to_anchor=(0.5, -0.15), markerscale=2)

    plt.tight_layout()
    plt.show()

In [None]:
plot_kinpfn_on_selected_testing_set_seq(trained_model, raw_sorted_filtered_il1a, bmdm_sorted_filtered_il1a, raw_sorted_filtered_il1b, bmdm_sorted_filtered_il1b, raw_sorted_filtered_tnfa, bmdm_sorted_filtered_tnfa, seed=8986)
plot_kinpfn_on_selected_testing_set_seq_appendix(trained_model, raw_sorted_filtered_il1a, bmdm_sorted_filtered_il1a, raw_sorted_filtered_il1b, bmdm_sorted_filtered_il1b, raw_sorted_filtered_tnfa, bmdm_sorted_filtered_tnfa, seed=107)