In [None]:
import pandas as pd
import torch
import torchvision
from datasets import IterableDataset, Dataset
from torch import nn
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor, ViTForImageClassification

from data import load_labels, load_frames
from visualize import plot_outcome_distribution

In [None]:
def generator(reduce_fps_factor, downscale_factor,
              experiments=("a", "b", "c", "d", "e"),
              positions=range(1, 10),
              settings_path='./data/experiments_settings.csv'):
    """Yield one example dict per annotated frame for every experiment/position run.

    For each run id ``f"{exp}{pos}"``, the frame bounds come from the settings CSV
    columns 'Starting Frame' and 'End Frame Annotation'. Each bound is divided by
    ``reduce_fps_factor`` and truncated with ``int()``, so it indexes the
    frame-rate-reduced video.

    Args:
        reduce_fps_factor: forwarded to ``load_frames``/``load_labels``; also used
            to rescale the CSV frame bounds.
        downscale_factor: forwarded to ``load_frames``.
        experiments: experiment letters to load (default: the original a-e).
        positions: positions to load per experiment (default: 1-9).
        settings_path: CSV read for the 'Experiment', 'Starting Frame',
            'End Frame Annotation' and 'Treatment' columns.

    Yields:
        dict with keys 'experiment', 'position', 'frame' (index relative to the
        start frame), 'image' (``frames[i]``), 'treatment' (numpy int from
        ``.astype(int)``) and 'outcome' (``labels[i, :]``).

    Raises:
        KeyError: if the settings CSV has no row for a requested run id.
    """
    settings = pd.read_csv(settings_path)
    for exp in experiments:
        print(f"Loading experiment {exp}")
        for pos in positions:
            print(f"Loading position {pos}")
            run_id = f'{exp}{pos}'
            # Filter the settings table once per run rather than once per column.
            run_settings = settings[settings.Experiment == run_id]
            if run_settings.empty:
                # Without this check, `.values[0]` fails with an uninformative IndexError.
                raise KeyError(f"No settings row for experiment {run_id!r} in {settings_path}")
            start_frame = int(run_settings['Starting Frame'].values[0] / reduce_fps_factor)
            end_frame = int(run_settings['End Frame Annotation'].values[0] / reduce_fps_factor)
            treatment = run_settings['Treatment'].values[0].astype(int)
            # load file .mkv
            frames = load_frames(exp, pos,
                                 reduce_fps_factor=reduce_fps_factor,
                                 downscale_factor=downscale_factor,
                                 start_frame=start_frame,
                                 end_frame=end_frame)
            # load annotations
            labels = load_labels(exp, pos,
                                 reduce_fps_factor=reduce_fps_factor,
                                 start_frame=start_frame,
                                 end_frame=end_frame)
            # NOTE(review): assumes load_frames/load_labels return at least
            # end_frame - start_frame rows; confirm against data.py.
            for i in range(end_frame - start_frame):
                yield {
                    "experiment": exp,
                    "position": pos,
                    "frame": i,
                    "image": frames[i],
                    "treatment": treatment,
                    "outcome": labels[i, :],
                }
                

In [None]:
# Materialise the frame generator (10x fewer fps, images scaled by 0.4) into an
# Arrow-backed Dataset and cache it on disk for the cells below.
dataset = Dataset.from_generator(
    generator,
    gen_kwargs=dict(reduce_fps_factor=10, downscale_factor=0.4),
)
dataset.save_to_disk("./data/train")

In [None]:
# Reload the cached split. Expose image/treatment/outcome as torch tensors and
# keep the remaining columns (experiment, position, frame) in their raw form.
train = Dataset.load_from_disk("./data/train")
train.set_format(
    type="torch",
    columns=["image", "treatment", "outcome"],
    output_all_columns=True,
)

In [None]:
# Plot the distribution of the outcome labels in the training split.
# NOTE(review): `save=True` presumably also writes the figure to disk — confirm in visualize.py.
plot_outcome_distribution(train, save=True)

In [None]:
# Take the first training frame as a sample input for the ViT smoke test.
# With the torch format set on `train`, this should be a tensor; its layout
# (e.g. H x W x C) is determined by data.load_frames — TODO confirm.
image = train[0]['image']

In [None]:
# Smoke test: push one frame through a pretrained ViT and keep every layer's hidden
# states, so the final-layer [CLS] token can be inspected as a candidate embedding.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Preprocess the sampled frame `image` (previously the builtin `input` was passed here by mistake).
inputs = processor(images=image, return_tensors="pt")
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
# For ImageNet class predictions instead, use outputs.logits with model.config.id2label.

In [None]:
# Size of the final-layer token at position 0 ([CLS]) for each image in the batch.
outputs.hidden_states[-1][:, 0].size()

In [None]:
# Placeholder transform: add an 'emb1' column holding each image multiplied by 10,
# processed in batches of 600 across 6 worker processes.
# NOTE(review): if images are stored as uint8, `* 10` wraps on overflow — confirm
# the dtype before relying on this column.
train = train.map(
    lambda batch: {"emb1": batch['image'] * 10},
    batched=True,
    batch_size=600,
    num_proc=6,
)

In [None]:
# Build a second variant of the data at downscale_factor=0.5. Dataset.from_generator
# already materialises the generator into an Arrow-backed Dataset, so the
# IterableDataset round-trip (and its undocumented `to_dataset()`) is unnecessary.
dataset = Dataset.from_generator(generator, gen_kwargs={"reduce_fps_factor": 10, "downscale_factor": 0.5})
# Save into a dedicated sub-directory. Writing straight into ./data would mix Arrow
# shards and dataset metadata with the raw inputs (experiments_settings.csv) and
# the ./data/train split.
dataset.save_to_disk("./data/train_downscale_05")

In [None]:
# Peek at one collated batch. `dataset` has no format set, so without
# with_format("torch") the default collate would return nested Python lists
# instead of stacked tensors, and `batch['image'].shape` would fail.
dataloader = DataLoader(dataset.with_format("torch"), batch_size=3, num_workers=0)
batch = next(iter(dataloader))
# Show tensor shapes (and raw values for non-tensor columns) instead of dumping full image tensors.
{key: getattr(value, "shape", value) for key, value in batch.items()}

In [None]:
# Dimensions of the collated image tensor for the peeked batch.
batch["image"].size()

In [None]:
# Example generator for IterableDataset sharding: lists in gen_kwargs become shards.
def my_generator(n, sources):
    """Yield {"example_id": "<source>_<k>"} for k in 0..n-1, one source after another."""
    for src in sources:
        yield from ({"example_id": f"{src}_{k}"} for k in range(n))

gen_kwargs = dict(
    n=10,
    sources=["path/to/data_{}".format(i) for i in range(1024)],
)
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
# The `sources` list is split across shards, so this reports one shard per source: 1024.
my_iterable_dataset.n_shards

In [None]:
class Model(nn.Module):
    """Frozen pretrained ResNet-50 encoder followed by a trainable linear head.

    Args:
        num_outputs: width of the linear head. Defaults to 10, the previously
            hard-coded value. NOTE(review): confirm this matches the width of the
            dataset's 'outcome' labels.
    """

    def __init__(self, num_outputs=10):
        super().__init__()
        # `pretrained=True` loads ImageNet weights; newer torchvision versions prefer
        # `weights=` and warn on this form. The encoder is frozen (no grads) and in eval mode.
        self.pretrained_encoder = torchvision.models.resnet50(pretrained=True).eval().requires_grad_(False)
        # ResNet-50's final layer emits 1000 ImageNet logits, used here as the embedding.
        self.linear = nn.Linear(1000, num_outputs)

    def train(self, mode=True):
        """Set train/eval mode on the model while keeping the frozen encoder in eval mode.

        nn.Module.train() recurses into every submodule. Without this override, the
        encoder's BatchNorm layers would switch back to training mode and keep
        updating their running statistics even though its weights are frozen.
        """
        super().train(mode)
        self.pretrained_encoder.eval()
        return self

    def forward(self, x):
        """Return head logits for the batch dict `x` (only `x['image']` is read).

        NOTE(review): ResNet-50 expects a float tensor of shape (batch, 3, H, W)
        normalized with ImageNet statistics — confirm the dataset images are
        converted accordingly before calling.
        """
        # The encoder is frozen; no_grad avoids storing activations for backprop.
        with torch.no_grad():
            emb1 = self.pretrained_encoder(x['image'])
        return self.linear(emb1)
    