In [1]:
import argparse
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pytorch_pretrained_bert import BertAdam

from utils.helpers import get_data_loaders
from models import get_model
from utils.logger import create_logger
from utils.utils import *

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="6"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import sys

In [2]:
args = argparse.Namespace(
        bert_model="bert-base-uncased",
        data_path="data_sets/proc/",
        task="moviescope",
        model="mmbt",
        batch_sz=8,
        task_type="multilabel",
        aug=0,
        drop_img_percent=0,
        max_seq_len=512,
        num_image_embeds=3,
        n_workers=8,
        poster="raw"
    )


In [3]:
train_loader, val_loader, test_loader = get_data_loaders(args)

In [4]:
from itertools import cycle

In [5]:
for batch in tqdm(zip(train_loader,cycle(val_loader)), total=len(train_loader)):
    txt, segment, mask, img, img_aug, tgt, img_id= batch
    print(img_id)
    pass

  0%|          | 1/432 [00:00<06:16,  1.14it/s]

[2727, 238, 4340, 2057, 2928, 4721, 2225, 3006]
[4747, 2370, 3638, 8, 4857, 896, 3064, 4050]


  2%|▏         | 9/432 [00:01<00:39, 10.79it/s]

[2347, 1058, 4295, 1461, 2527, 3687, 3850, 2848]
[2625, 3280, 1575, 0, 2198, 4321, 98, 4177]
[2798, 1380, 701, 4446, 2828, 4505, 1413, 3924]
[4951, 1529, 4084, 2477, 2818, 3487, 220, 3136]
[2973, 4756, 1237, 3092, 4490, 3094, 2247, 3433]
[3372, 2371, 660, 1249, 934, 2267, 2335, 25]
[2554, 2268, 3841, 1722, 3696, 1242, 2764, 1128]
[3169, 426, 4893, 851, 4751, 67, 2742, 2646]


  3%|▎         | 12/432 [00:01<00:31, 13.54it/s]

[1979, 3901, 990, 1624, 388, 681, 3922, 3605]
[1733, 1669, 3473, 4993, 3807, 1269, 4960, 3580]
[2933, 3551, 400, 4553, 3196, 1542, 2642, 1942]
[1134, 4485, 950, 2952, 1638, 3546, 3782, 4304]
[1212, 1319, 1391, 3699, 4734, 845, 4005, 2780]
[4246, 1312, 558, 4105, 1901, 1656, 4482, 4891]
[3191, 1346, 2039, 372, 2311, 709, 1021, 499]
[1375, 4577, 3274, 4428, 124, 4768, 945, 1745]


  4%|▍         | 19/432 [00:01<00:20, 19.70it/s]

[1941, 3466, 674, 3853, 2419, 2021, 228, 570]
[1116, 4308, 1291, 3162, 2275, 1224, 1138, 4410]
[3443, 4525, 2123, 200, 2964, 3501, 3564, 2988]
[1597, 3378, 177, 1934, 2153, 4810, 1355, 2545]
[3891, 1794, 2762, 3248, 4506, 661, 2889, 1399]
[3576, 5, 189, 509, 1598, 1944, 2456, 560]
[1511, 3513, 4395, 1773, 3175, 1770, 2145, 1196]
[1951, 1259, 2028, 4759, 683, 547, 1181, 3711]


 10%|▉         | 43/432 [00:01<00:08, 44.08it/s]

[2504, 172, 3108, 1360, 1658, 4374, 1731, 2967]
[2908, 1707, 1475, 2468, 1414, 4489, 2677, 1639]
[4783, 386, 3241, 4472, 490, 3562, 4852, 2014]
[2537, 559, 156, 1518, 1873, 4447, 929, 1324]
[4259, 712, 307, 3107, 2572, 2726, 609, 87]
[835, 2808, 2144, 4772, 4160, 740, 2577, 626]
[1763, 1537, 3503, 2443, 4171, 1307, 4421, 2427]
[1968, 2982, 1411, 579, 4437, 204, 344, 1987]
[936, 3645, 1059, 1463, 2433, 1305, 141, 2137]
[7, 3231, 4475, 2414, 4994, 4724, 229, 2671]
[2330, 4961, 4855, 2636, 4402, 2172, 2604, 4401]
[3581, 4929, 58, 1185, 2118, 1139, 4664, 4870]
[4871, 3963, 4844, 2373, 3090, 4793, 2241, 3844]
[1666, 4436, 4339, 4110, 312, 3754, 2817, 2876]
[2400, 4366, 2420, 265, 1513, 501, 4786, 508]
[1054, 2696, 1020, 4778, 1800, 285, 3610, 4322]
[525, 2533, 858, 2344, 254, 2217, 2289, 3052]
[1662, 117, 667, 4390, 4026, 4144, 2804, 3186]
[449, 1650, 3649, 3143, 1281, 4712, 176, 4604]
[2257, 485, 805, 4170, 1615, 2083, 244, 1079]
[895, 236, 1586, 2662, 1530, 1086, 4128, 295]
[404, 1886, 48

 13%|█▎        | 56/432 [00:02<00:07, 50.08it/s]

[3881, 414, 2831, 3778, 3789, 4258, 1063, 2051]
[2176, 1956, 1416, 62, 2033, 1742, 82, 1072]
[4478, 3542, 3093, 1226, 4464, 3833, 4507, 4392]
[477, 4618, 3761, 2989, 2927, 1393, 368, 2725]
[4089, 1012, 2721, 2705, 3460, 40, 4269, 4375]
[549, 1607, 4095, 4093, 1541, 135, 3843, 3389]


 14%|█▍        | 62/432 [00:02<00:07, 50.71it/s]

[891, 4971, 4861, 3290, 1558, 3865, 4945, 4252]
[3109, 4054, 1067, 1434, 4467, 917, 4647, 923]
[3325, 2722, 3693, 4661, 4262, 3225, 4789, 335]
[3664, 4719, 3588, 460, 4854, 4668, 3479, 4616]


 16%|█▌        | 68/432 [00:02<00:08, 41.61it/s]

[3690, 3178, 3892, 1573, 1278, 2379, 2260, 3947]
[3505, 2097, 1641, 699, 2589, 1287, 771, 1432]
[1723, 166, 1404, 3952, 4229, 781, 503, 3909]
[3744, 3113, 1519, 643, 1084, 2232, 3737, 2350]
[3948, 1524, 1911, 2334, 2924, 3941, 3068, 5032]
[3727, 639, 1563, 991, 3011, 3773, 4641, 182]
[2485, 1439, 4817, 3226, 1468, 4185, 4555, 1451]
[1465, 3629, 4104, 2120, 2421, 4534, 2513, 1964]
[2489, 3956, 4495, 1129, 1055, 4122, 3419, 2949]
[2588, 3121, 1709, 61, 321, 3524, 4573, 3801]
[2620, 3550, 3201, 4558, 2905, 937, 2212, 1097]
[356, 2285, 3238, 3279, 4425, 1550, 4178, 4380]
[2688, 1821, 527, 3423, 3033, 3640, 2857, 3379]
[2413, 4279, 24, 281, 3358, 2501, 995, 1029]


 18%|█▊        | 79/432 [00:02<00:10, 33.51it/s]

[4804, 3149, 4809, 2627, 1589, 1939, 2736, 1584]
[4889, 3916, 967, 115, 1812, 2067, 2459, 3533]
[2, 183, 3862, 2364, 3771, 4775, 3424, 453]
[3436, 1696, 1347, 3676, 1655, 4190, 909, 106]
[2903, 4156, 2310, 1304, 561, 1444, 3811, 944]
[467, 1867, 3722, 1357, 2659, 2266, 4193, 3394]
[1137, 4705, 3720, 2592, 1514, 2345, 2535, 2516]
[3797, 1622, 670, 809, 2546, 4065, 1536, 3989]


 20%|██        | 87/432 [00:03<00:10, 33.20it/s]

[3133, 4445, 4956, 4513, 647, 3554, 3914, 4671]
[4291, 532, 314, 505, 2238, 1628, 4370, 4670]
[3911, 3658, 3430, 1296, 1280, 354, 4233, 3735]
[169, 4297, 3376, 3398, 3078, 2093, 3293, 2766]


 25%|██▍       | 107/432 [00:03<00:08, 37.86it/s]

[1160, 472, 6, 2243, 3976, 3747, 3692, 4595]
[3032, 971, 3170, 3598, 1052, 2975, 2180, 919]
[2425, 3319, 309, 1995, 190, 1448, 855, 1703]
[2274, 2473, 5010, 726, 2316, 425, 1708, 2648]
[1273, 3514, 572, 4488, 3873, 3798, 2626, 487]
[3752, 249, 693, 1551, 2186, 1605, 3714, 687]
[3758, 3140, 3104, 2065, 3599, 2953, 464, 1757]
[761, 2471, 5026, 1107, 3969, 1828, 745, 194]
[205, 4398, 1561, 1953, 3243, 3975, 3535, 2071]
[1580, 1854, 2651, 2312, 636, 3482, 671, 2827]
[5000, 3046, 3805, 3316, 2324, 3998, 2530, 3661]
[578, 2562, 277, 1301, 1790, 1983, 3060, 3458]
[2309, 2847, 5015, 1109, 654, 1071, 5009, 4780]
[1798, 3381, 3223, 3955, 568, 3268, 1683, 4041]
[3984, 1985, 81, 4743, 4659, 1816, 2581, 2936]
[3652, 2152, 4288, 4261, 1884, 869, 4357, 2293]
[4683, 1931, 816, 3958, 4307, 3086, 186, 215]
[3062, 201, 2015, 110, 308, 2230, 3130, 3529]
[45, 2220, 3651, 4263, 1540, 4317, 286, 778]
[3967, 4192, 2509, 849, 225, 2130, 714, 4603]
[4352, 898, 1028, 2555, 2565, 153, 3925, 1863]
[4184, 2328, 286

 27%|██▋       | 115/432 [00:03<00:07, 43.52it/s]

[337, 1833, 102, 1735, 1147, 4199, 690, 3864]


 30%|██▉       | 129/432 [00:04<00:07, 39.97it/s]

[2410, 361, 434, 151, 2644, 4820, 323, 814]
[3048, 4083, 3260, 2249, 4779, 4133, 2558, 320]
[207, 2117, 1686, 2041, 996, 2352, 617, 5034]
[3992, 798, 2076, 1149, 2698, 3455, 1644, 4396]
[1421, 4589, 766, 2643, 2984, 1972, 3849, 4232]
[2784, 736, 721, 3704, 2113, 817, 1190, 1486]
[2657, 2133, 111, 3245, 4382, 479, 4551, 3029]
[2830, 2154, 1739, 4098, 161, 1458, 4135, 231]
[1082, 548, 2514, 4843, 3636, 1384, 405, 2629]
[4838, 4575, 3355, 3267, 2681, 4197, 2438, 3495]
[3053, 799, 1811, 2824, 1174, 3525, 497, 4336]
[697, 4236, 504, 3299, 1779, 2865, 834, 1321]
[1822, 2056, 1158, 1292, 2264, 3668, 3772, 2877]
[1311, 2526, 1474, 768, 4345, 3454, 3806, 696]
[1526, 4642, 550, 352, 3309, 3779, 4162, 829]
[4773, 484, 2401, 4710, 3145, 4077, 3609, 893]
[5012, 1231, 2550, 3917, 2406, 3814, 4113, 3250]
[751, 3820, 3200, 3323, 839, 3429, 2797, 3949]
[168, 2463, 1667, 840, 4150, 262, 1929, 1883]
[4055, 103, 1793, 43, 538, 836, 3180, 4347]
[3167, 4020, 4528, 3734, 1320, 4931, 993, 3756]
[1719, 1157, 6

 36%|███▌      | 156/432 [00:04<00:03, 69.73it/s]

[2355, 1544, 2532, 3300, 3470, 3619, 1506, 2799]
[199, 420, 1482, 458, 1179, 775, 4351, 1038]
[1410, 1271, 366, 987, 3181, 2715, 2446, 96]
[1918, 2774, 27, 3100, 524, 1156, 395, 2754]
[4729, 938, 1613, 619, 4535, 2868, 358, 3800]
[4206, 4214, 1497, 4241, 3327, 1213, 4761, 4362]
[4173, 4830, 926, 3656, 1466, 2089, 2807, 3822]
[980, 1354, 438, 2773, 2922, 2313, 3517, 4329]
[4378, 4319, 1606, 4294, 1368, 3510, 2029, 171]
[494, 2785, 4477, 280, 3913, 4588, 3556, 10]
[1176, 4799, 3474, 2333, 1283, 4989, 1001, 1177]
[2487, 2999, 2038, 4331, 4905, 1085, 2223, 2119]
[351, 1081, 2418, 2940, 4997, 4068, 2917, 551]
[2966, 706, 475, 1069, 333, 630, 57, 1188]
[4030, 4315, 2281, 3328, 975, 4337, 5035, 622]
[3193, 1853, 3918, 797, 3096, 3059, 3073, 1938]
[586, 2095, 498, 685, 237, 684, 3341, 3579]
[3219, 1419, 979, 2151, 1227, 1308, 3457, 146]
[1268, 589, 1819, 4954, 1362, 4406, 3417, 4112]
[3872, 903, 4862, 3890, 2957, 625, 3628, 2854]
[1673, 53, 4130, 1751, 872, 3544, 3625, 3590]
[129, 4860, 2185, 

 44%|████▎     | 188/432 [00:04<00:02, 104.10it/s]

[598, 253, 3101, 4824, 2521, 4484, 444, 26]
[2800, 1977, 4691, 785, 3322, 3688, 2246, 1257]
[810, 3400, 2402, 1251, 2143, 4607, 5036, 4953]
[4943, 2105, 4358, 3435, 3537, 4966, 4225, 4031]
[645, 3294, 803, 4869, 1381, 638, 3781, 290]
[46, 2734, 4682, 1415, 3538, 3595, 4690, 2660]
[1433, 1755, 1986, 2947, 2633, 20, 1879, 54]
[3122, 1328, 1327, 3468, 2717, 3884, 415, 4622]
[2796, 4649, 2240, 4686, 2524, 3129, 422, 4159]
[2959, 3384, 708, 744, 801, 1778, 3701, 1515]
[4917, 843, 2178, 4243, 3848, 3002, 2978, 4276]
[2690, 4716, 3698, 4194, 4680, 3895, 4847, 2702]
[1233, 4441, 4350, 4925, 4652, 398, 4087, 411]
[3785, 2720, 2403, 1018, 3792, 1671, 596, 299]
[3565, 1501, 4103, 3114, 2184, 2755, 1971, 3120]
[1787, 4172, 2394, 3320, 3824, 4220, 2499, 2691]
[3571, 1100, 4950, 2563, 38, 3935, 4886, 2036]
[4922, 4014, 4469, 1246, 2849, 1998, 97, 2612]
[4223, 3021, 974, 3972, 4892, 5025, 2855, 1382]
[4531, 2918, 1756, 1792, 399, 3066, 3608, 2757]
[92, 283, 3164, 2601, 1102, 3930, 4006, 1253]
[2091, 

 51%|█████     | 220/432 [00:04<00:01, 122.38it/s]

[759, 3527, 4568, 3855, 3907, 3828, 1493, 83]
[4456, 4559, 2342, 3803, 4777, 1183, 4400, 4058]
[3218, 4906, 776, 1578, 1389, 4585, 4372, 2801]
[3236, 148, 652, 2188, 733, 1186, 3340, 248]
[3080, 4365, 1842, 3726, 2044, 2895, 3022, 2823]
[1006, 63, 2929, 969, 862, 1111, 2388, 4630]
[4403, 2417, 4572, 125, 593, 407, 965, 4222]
[136, 3353, 4674, 3023, 1306, 4011, 3589, 4651]
[3277, 446, 2566, 3356, 2294, 762, 3990, 3217]
[1553, 3075, 188, 3063, 2138, 1108, 4388, 2867]
[3405, 616, 2301, 724, 4102, 4119, 4186, 3743]
[1913, 392, 2357, 4234, 4958, 951, 2191, 2897]
[4792, 1467, 3713, 3685, 599, 4986, 4514, 4097]
[2179, 2733, 4977, 899, 3637, 1887, 4442, 4126]
[3362, 3647, 4153, 2792, 3301, 34, 2825, 1104]
[1775, 1729, 2539, 922, 649, 4306, 4754, 3665]
[1255, 3802, 1768, 750, 30, 1900, 4045, 3076]
[1400, 1653, 959, 575, 3253, 743, 841, 2753]
[3132, 3887, 4678, 4709, 3534, 1112, 4584, 2010]
[1614, 2190, 3496, 927, 3549, 3125, 196, 3079]
[2399, 4504, 222, 1675, 4704, 3447, 4074, 500]
[4281, 4901,

 57%|█████▋    | 248/432 [00:05<00:01, 121.36it/s]

[119, 100, 2209, 2483, 223, 658, 181, 3842]
[4231, 3502, 739, 657, 1980, 1473, 4806, 2011]
[3202, 353, 162, 1661, 4354, 4896, 1235, 4822]
[2398, 2035, 3050, 4457, 3753, 3119, 3183, 3246]
[2062, 4900, 2132, 2458, 3397, 4397, 3375, 1876]
[1626, 3281, 4060, 583, 1376, 4932, 1636, 1361]
[2919, 1477, 3264, 1428, 3520, 2900, 4167, 4646]
[2880, 3098, 2594, 2327, 2080, 1239, 1089, 2340]
[3117, 131, 362, 1333, 1470, 4620, 2129, 1200]
[1049, 2048, 2741, 2540, 3728, 3937, 232, 1064]
[1838, 5005, 2759, 2194, 1454, 473, 2680, 2381]
[2164, 3999, 271, 1488, 695, 1533, 4201, 4658]
[3700, 1370, 4788, 1874, 2495, 1936, 1860, 3160]
[602, 1077, 1367, 3632, 824, 2060, 2624, 1]
[635, 107, 3725, 4013, 2960, 257, 2218, 1823]
[439, 2273, 3821, 4053, 985, 456, 916, 4127]
[2769, 3172, 2434, 2575, 4381, 3425, 1155, 4864]
[1691, 1387, 3985, 2346, 2863, 2055, 716, 440]
[2101, 3596, 288, 2888, 1300, 3095, 4920, 1593]
[4140, 3258, 2103, 2843, 3812, 952, 3171, 1019]
[2811, 3412, 217, 1123, 4614, 610, 4877, 1952]
[2576

 64%|██████▍   | 277/432 [00:05<00:01, 121.42it/s]

[1990, 1700, 2976, 3738, 4090, 2910, 1976, 2861]
[2777, 219, 764, 2549, 1023, 941, 3960, 4728]
[1254, 1003, 3009, 1562, 3627, 1144, 2457, 1159]
[3882, 4602, 2704, 732, 142, 1195, 2632, 3047]
[4438, 1132, 326, 3603, 800, 1406, 4914, 3799]
[1441, 384, 2474, 4466, 3213, 4980, 4510, 2278]
[3489, 2599, 4224, 1803, 644, 1500, 2573, 4239]
[261, 448, 737, 1169, 4617, 4944, 3795, 3851]
[2610, 462, 2510, 3483, 2944, 94, 18, 1245]
[4313, 104, 112, 150, 2668, 75, 341, 795]
[4831, 2375, 1836, 4280, 4735, 224, 4557, 1846]
[1743, 1663, 2256, 3150, 2761, 3750, 3359, 291]
[2323, 3680, 3897, 1601, 4468, 1113, 2050, 2716]
[1377, 1040, 1096, 230, 2606, 339, 1959, 3155]
[1904, 3736, 1590, 2707, 1872, 2816, 442, 2850]
[2242, 2317, 350, 2219, 1175, 2934, 2292, 1988]
[1073, 1577, 3552, 1651, 3249, 4791, 2032, 1791]
[999, 5029, 2162, 3428, 1286, 4311, 961, 3051]
[1592, 2980, 1031, 3498, 3012, 3221, 2593, 2552]
[1366, 3499, 3469, 2963, 1914, 876, 1436, 3555]
[4301, 2435, 3547, 3127, 3256, 3192, 1679, 3131]
[273

 72%|███████▏  | 309/432 [00:05<00:00, 135.04it/s]

[1862, 4938, 2805, 1845, 3350, 2663, 3471, 1371]
[1136, 1026, 738, 4244, 2351, 2074, 3242, 3888]
[3273, 3173, 3354, 4008, 3670, 2122, 2476, 734]
[1318, 2059, 2763, 3344, 4283, 4432, 2447, 3364]
[3788, 3507, 3016, 2160, 455, 47, 305, 428]
[2503, 2719, 2454, 4785, 847, 4672, 410, 1037]
[3796, 1247, 3614, 1141, 2547, 4912, 2788, 345]
[2321, 4902, 1978, 1339, 3383, 698, 515, 2116]
[1234, 906, 4136, 1326, 1105, 1256, 1004, 4669]
[4032, 4769, 4694, 1928, 3749, 1142, 2437, 1668]
[4096, 86, 4361, 3869, 3973, 4143, 955, 1010]
[4493, 1802, 2075, 3954, 854, 2570, 4656, 1572]
[577, 2196, 4640, 324, 3686, 2034, 1424, 2523]
[3575, 4208, 4429, 4107, 4368, 4409, 5024, 2748]
[1122, 1645, 1646, 1699, 4594, 4802, 3142, 4314]
[4040, 3410, 4823, 1229, 4687, 2233, 4627, 2745]
[1789, 2536, 1455, 2429, 2835, 4887, 2136, 2812]
[569, 2920, 140, 3128, 633, 4648, 3557, 3154]
[4080, 1192, 1279, 3159, 4292, 603, 2245, 1717]
[5013, 1364, 3613, 2779, 894, 3005, 4121, 4663]
[1962, 1349, 1510, 675, 2776, 457, 4497, 175

 79%|███████▉  | 341/432 [00:05<00:00, 139.30it/s]

[4038, 2215, 3403, 329, 2175, 242, 1490, 567]
[4341, 1442, 1369, 3072, 2307, 180, 1429, 3153]
[2898, 2887, 2201, 1125, 2723, 4909, 1832, 1807]
[4266, 1425, 516, 691, 2803, 3762, 1162, 752]
[4069, 2061, 541, 2750, 1865, 1694, 2022, 1910]
[3288, 655, 4707, 4344, 48, 4215, 3991, 1975]
[3932, 2341, 4667, 4679, 3001, 4461, 4299, 488]
[4634, 3439, 427, 3578, 3282, 347, 3716, 1765]
[5011, 962, 3368, 3461, 539, 1060, 4919, 2746]
[37, 966, 4916, 3380, 748, 424, 1238, 528]
[120, 1263, 604, 1215, 1547, 1899, 2496, 3893]
[1043, 4399, 3385, 1917, 3081, 3472, 5017, 1961]
[2299, 3056, 1618, 4983, 4479, 1674, 760, 266]
[4738, 1145, 4470, 1163, 4967, 2542, 1565, 3232]
[3227, 1250, 2100, 3928, 240, 4391, 2004, 2339]
[1654, 1297, 4022, 2851, 2600, 1924, 1170, 1394]
[1450, 1261, 381, 850, 4333, 437, 720, 2983]
[1309, 2508, 4908, 21, 1504, 2461, 607, 4879]
[2390, 4033, 1604, 470, 4453, 4176, 2452, 3318]
[679, 793, 3179, 1958, 4189, 3809, 1725, 3585]
[1692, 1933, 3024, 1568, 2298, 722, 1984, 2007]
[4474, 35

 82%|████████▏ | 356/432 [00:05<00:00, 140.79it/s]

[334, 88, 3894, 3504, 4012, 2611, 1091, 1053]
[3857, 1207, 986, 4115, 1205, 1772, 2712, 3724]
[2701, 3981, 1329, 4419, 1926, 1702, 531, 3751]
[1891, 3582, 132, 2977, 2462, 2234, 4699, 4332]
[1848, 2453, 4952, 1764, 713, 3936, 1408, 2226]
[3017, 4613, 1684, 3825, 2670, 3702, 5022, 2349]
[3399, 2436, 3706, 819, 1520, 1705, 72, 2607]
[16, 920, 4592, 1041, 4512, 4328, 1642, 746]
[4597, 2782, 2655, 2159, 3592, 3683, 4248, 3084]
[2165, 557, 4660, 3867, 2484, 3970, 4082, 4211]
[1559, 611, 2475, 2216, 433, 1600, 2099, 3346]
[2969, 3097, 3838, 138, 4605, 912, 827, 3863]
[4866, 621, 666, 159, 648, 1783, 4327, 1166]
[145, 4726, 3174, 1352, 3618, 1087, 3124, 3522]
[2658, 3450, 3387, 3271, 3616, 1083, 2092, 953]
[3014, 13, 988, 705, 4841, 3028, 989, 2286]
[3026, 2531, 4021, 4770, 1965, 2199, 2791, 391]


 89%|████████▉ | 384/432 [00:06<00:00, 116.05it/s]

[4934, 3185, 279, 4422, 848, 2107, 1148, 2930]
[1747, 2338, 3558, 1866, 1182, 3311, 1341, 175]
[1797, 184, 3347, 1243, 1426, 2616, 1890, 4057]
[3500, 4027, 2326, 137, 2806, 1748, 65, 3324]
[629, 2772, 1955, 1546, 2372, 4139, 3553, 4312]
[3190, 856, 889, 3493, 2787, 396, 707, 2623]
[2710, 2204, 377, 2615, 4760, 1814, 730, 1903]
[3563, 2955, 553, 3548, 3272, 4545, 2569, 3437]
[4343, 1395, 365, 2832, 1388, 3988, 2939, 702]
[36, 4111, 4079, 3392, 4434, 4141, 373, 2498]
[2781, 2063, 506, 982, 4491, 52, 59, 1989]
[2945, 3156, 4999, 954, 2749, 2937, 4750, 4542]
[573, 1099, 4650, 2580, 2987, 2466, 808, 4533]
[1420, 1527, 163, 624, 3409, 3333, 4216, 2478]
[2718, 3965, 1711, 95, 1046, 4226, 1248, 2553]
[1670, 2025, 4600, 4878, 1815, 76, 4546, 4638]
[1997, 3166, 2047, 4530, 972, 4523, 3205, 4320]
[4745, 289, 1114, 2197, 3265, 1202, 2270, 3703]
[3834, 4923, 4805, 4544, 2568, 160, 1479, 3138]
[2890, 1531, 2440, 2148, 134, 1498, 4569, 2189]
[4305, 1062, 2283, 2115, 1631, 1459, 303, 4970]
[554, 1509,

 95%|█████████▌| 412/432 [00:06<00:00, 127.30it/s]

[3334, 4698, 2751, 3607, 4138, 4267, 820, 1732]
[3996, 1211, 1753, 2287, 2254, 2009, 1946, 3729]
[2288, 1216, 3497, 3045, 1044, 1949, 1317, 4930]
[465, 3490, 978, 4819, 4100, 1915, 4412, 2574]
[4717, 1117, 1993, 121, 1706, 1088, 1831, 4619]
[4814, 2556, 12, 1966, 66, 1804, 1397, 1629]
[1337, 3459, 3337, 2205, 1025, 1570, 1334, 3431]
[2442, 3979, 4526, 1754, 245, 2760, 336, 3475]
[2756, 641, 4473, 3438, 4023, 591, 3819, 2968]
[3604, 4431, 1180, 533, 4324, 2182, 3593, 5037]
[4632, 1647, 274, 316, 4894, 3044, 4921, 3654]
[948, 2043, 2305, 2520, 4101, 1161, 3187, 2168]
[2393, 719, 418, 595, 704, 147, 1476, 3013]
[2676, 429, 3139, 1981, 4180, 2003, 3783, 780]
[4271, 1034, 3363, 2543, 833, 964, 2689, 4268]
[3657, 3511, 4205, 2070, 3994, 1187, 2343, 474]
[3414, 216, 3876, 3712, 4394, 1817, 4016, 3366]
[1061, 3622, 1957, 4147, 2008, 2892, 2802, 56]
[51, 5020, 767, 1036, 3620, 4154, 1595, 3793]
[4685, 357, 1365, 1649, 3509, 2318, 2635, 1198]
[5018, 668, 3110, 2505, 378, 1303, 4325, 4019]
[4655,

100%|██████████| 432/432 [00:06<00:00, 65.23it/s] 

[1851, 774, 250, 1299, 346, 1695, 973, 478]
[14, 2893, 4895, 2207, 3007, 213, 863, 755]
[376, 3254, 3184, 2488, 174, 3251, 3074, 2584]
[2166, 4071, 3163, 2295, 2858, 662, 2320, 2731]
[1342, 882, 1074, 252, 2126, 3343, 1608, 3197]
[1494]





In [42]:
for i,  in enumerate(zip(train_loader,cycle(val_loader))):
    txt, segment, mask, img, img_aug, tgt, img_id= batch[0]
    txt, segment, mask, img, img_aug, tgt, img_id= batch[0]
    print(img_id)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [2]:
from datasets import load_dataset

In [3]:
dataset = load_dataset('oscar', 'unshuffled_deduplicated_it')

Reusing dataset oscar (/002/usuarios/ivonne.monter/.cache/huggingface/datasets/oscar/unshuffled_deduplicated_it/1.0.0/84838bd49d2295f62008383b05620571535451d84545037bb94d6f3501651df2)


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'text'],
        num_rows: 28522082
    })
})

In [None]:
from tqdm.auto import tqdm

text_data = []
file_count = 0

for sample in tqdm(dataset['train']):
    sample = sample['text'].replace('\n', '')
    text_data.append(sample)
    if len(text_data) == 10_000:
        # once we git the 10K mark, save to file
        with open(f'dataset/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
            fp.write('\n'.join(text_data))
        text_data = []
        file_count += 1
# after saving in 10K chunks, we will have ~2082 leftover samples, we save those now too
with open(f'dataset/text_{file_count}.txt', 'w', encoding='utf-8') as fp:
    fp.write('\n'.join(text_data))

  0%|          | 0/28522082 [00:00<?, ?it/s]