In [2]:
from google.colab import drive

# Attach Google Drive so the notebook can read code/data from it and write logits back.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Copy the project helper modules from Drive into /content so they can be imported below.
!cp /content/drive/MyDrive/code/env.py /content/drive/MyDrive/code/dataloader.py /content/drive/MyDrive/code/get_embeddings.py /content/

In [1]:
# Copy the frame archive from Drive into /content.
# Disabled alternative: copy random_data.zip (saved locally as correct_data.zip) instead.
# !cp /content/drive/MyDrive/random_data.zip /content/correct_data.zip
!cp /content/drive/MyDrive/selected_frames.zip /content/selected_frames.zip

In [2]:
import os
import shutil

# Extract the copied archive into a sibling 'selected_frames' directory.
# Both paths are relative, so this resolves against the current working directory.
# print(os.getcwd())
shutil.unpack_archive(filename='selected_frames.zip', extract_dir='selected_frames')

In [4]:
import os
import torch
import pickle
import random
import argparse
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
from env import set_seed
from dataloader import Dataloader
from torch.utils.data import TensorDataset, DataLoader, Dataset

import tensorflow as tf
import tensorflow_hub as hub

# from official.projects.movinet.modeling import movinet
# from official.projects.movinet.modeling import movinet_model

import warnings
warnings.filterwarnings("ignore")

In [4]:
# Fail fast unless TensorFlow reports the first GPU as available.
device_name = tf.test.gpu_device_name()
gpu_found = device_name == '/device:GPU:0'
if not gpu_found:
  raise SystemError('GPU device not found')

In [15]:
# Run configuration, read by the cells below via attribute access (args.<name>).
# A plain namespace is used rather than an empty DataFrame: setting new
# attributes on a DataFrame is not a supported way to store values and makes
# pandas emit a warning about column creation.
args = argparse.Namespace(
    batch_size=2,
    logit_dir="/content/drive/MyDrive/logits",
    is_random=True,
    num_classes=5,
    num_frames=5,
    seed=42,
    model_name="movinet",
    dataset_type="noise",
    # NOTE(review): the archive is unpacked into 'selected_frames' earlier in
    # this notebook, but this path points at a different directory -- confirm
    # which dataset is intended before running.
    video_dir_path="/content/extracted_correct_data/random_data",
)

In [6]:
# Apply the configured seed through the project helper from env.py
# (presumably seeds the RNGs used in this notebook -- confirm in env.py).
set_seed(args.seed)

# Make sure the output directory exists. exist_ok=True replaces the separate
# existence check, so there is no window between checking and creating.
os.makedirs(args.logit_dir, exist_ok=True)

In [13]:
print("\n######################## Loading Model ########################\n")

# MoViNet-A2 (base, Kinetics-600 classification) from TF Hub, loaded as a frozen layer.
hub_url = "https://tfhub.dev/tensorflow/movinet/a2/base/kinetics-600/classification/3"
encoder = hub.KerasLayer(hub_url, trainable=False)

# Symbolic input with three channels in the last axis; all other dimensions are left unspecified.
inputs = tf.keras.layers.Input(
    shape=[None, None, None, 3],
    dtype=tf.float32,
    name='image',
)

# [batch_size, 600]
outputs = encoder({'image': inputs})


######################## Loading Model ########################



In [14]:
# Wrap the hub encoder into a Keras model mapping the 'image' input to the encoder output.
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='movinet')

In [7]:
print("\n######################## Loading Data ########################\n")
dataloader = Dataloader(args.video_dir_path, num_classes=args.num_classes)

# Pair every loaded instance with its label, then group the pairs into batches.
instances_and_labels = (dataloader.instances, dataloader.labels)
tensor_dataset = tf.data.Dataset.from_tensor_slices(instances_and_labels)
batched_dataset = tensor_dataset.batch(args.batch_size)


######################## Loading Data ########################



100%|██████████| 1000/1000 [00:14<00:00, 70.88it/s]


In [17]:
def get_logits(model, model_name, dataloader, device='/device:GPU:0', max_batches=None):
    """Run ``model`` over the batches of ``dataloader`` and collect its outputs.

    Args:
        model: Callable applied to each batch of video frames. Its result is
            iterated along the first axis to obtain one entry per sample.
        model_name: If equal to ``"movinet"`` the collected outputs are combined
            with ``tf.stack``; otherwise they are converted through NumPy into
            a torch tensor.
        dataloader: Iterable yielding ``(video_frames, label)`` batches.
        device: TensorFlow device string under which the forward pass runs.
        max_batches: Optional cap on the number of batches to process, useful
            for quick smoke tests. ``None`` (the default) processes every batch.

    Returns:
        Tuple ``(logits, labels)`` holding the stacked per-sample model outputs
        and the stacked per-sample labels, in iteration order.

    Raises:
        ValueError: If no batch was processed (empty ``dataloader`` or
            ``max_batches`` of 0), since there is nothing to stack.
    """
    logits = []
    labels = []

    for idx, (video_frames, label) in tqdm(enumerate(dataloader), position=0, leave=True):
        # Previously a hard-coded debug limit stopped after five batches, which
        # silently truncated the saved logits; the cap is now opt-in.
        if max_batches is not None and idx >= max_batches:
            break
        with tf.device(device):
            # Inference only: make sure no gradient flows through the outputs.
            outputs = tf.stop_gradient(model(video_frames))
        # extend() splits the batch along its first axis into per-sample entries.
        logits.extend(outputs)
        labels.extend(label)

    if not logits:
        raise ValueError("get_logits processed no batches; nothing to stack")

    if model_name != "movinet":
        logits = torch.from_numpy(np.array(logits))
    else:
        logits = tf.stack(logits)

    labels = tf.stack(labels)

    return logits, labels

In [None]:
print("\n######################## Getting Logits ########################\n")
logits, labels = get_logits(model, args.model_name, batched_dataset)

# Materialise the (logit, label) pairs: a bare zip object is a one-shot
# iterator, which is fragile to pickle and can only be consumed once after
# loading. A list yields the same pairs on iteration and can be re-read.
combined = list(zip(logits, labels))

logit_path = os.path.join(
    args.logit_dir,
    args.model_name + "_tf_selected_frames_" + args.dataset_type + ".pkl",
)
# The context manager guarantees the file is flushed and closed after writing.
with open(logit_path, "wb") as logit_file:
    pickle.dump(combined, logit_file)


######################## Getting Logits ########################



In [12]:
# Sanity check: show the first value of the first entry in `logits`.
print(logits[0][0])

tf.Tensor(0.3808917, shape=(), dtype=float32)


In [13]:
# Index of the largest value along the last axis of the first entry in `logits`.
tf.argmax(logits[0], -1)

<tf.Tensor: shape=(), dtype=int64, numpy=461>

In [17]:
import torch
from torch.nn.functional import softmax

# Pull the first entry of `logits` out of TensorFlow as a NumPy array for the torch softmax below.
first_entry = logits[0]
sm = first_entry.numpy()

In [18]:
# Shape of the extracted NumPy array.
sm.shape

(600,)

In [23]:
# Convert the values in `sm` to probabilities. `dim` is passed explicitly because
# calling softmax without it relies on a deprecated implicit-dimension
# choice and emits a warning; for this 1-D input the last axis is the only axis.
sm1 = softmax(torch.from_numpy(sm), dim=-1)

In [25]:
# Softmax value at index 461, the argmax index displayed earlier in this notebook.
sm1[461]

tensor(0.9233)

In [12]:
# Peek at the first batch only and report the shape of its first component.
for batch in batched_dataset:
  first_component = batch[0]
  print(first_component.shape)
  break

(16, 5, 224, 224, 3)
