In [3]:
import os
import numpy as np 
import pandas as pd 
from datetime import datetime
import time
import random
from tqdm.autonotebook import tqdm

#Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler

#sklearn
from sklearn.model_selection import StratifiedKFold

################# DETR FUNCTIONS FOR LOSS ########################
import sys
sys.path.append('./detr-fish/')

from models.matcher import HungarianMatcher
from models.detr import SetCriterion
#################################################################

import matplotlib.pyplot as plt

#Glob
from glob import glob

  from tqdm.autonotebook import tqdm


In [4]:
def seed_everything(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), exports ``PYTHONHASHSEED`` and puts cuDNN into a
    reproducible configuration.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one: the notebook runs on device 1.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick algorithms per run, which
    # undoes the determinism requested on the line above; keep it disabled.
    torch.backends.cudnn.benchmark = False
    
def pytorch_init_janus_gpu():
    """Select CUDA device 1 after validating the expected two-GPU host.

    All hardware checks run *before* the device is switched, so an unsuitable
    host is always reported as ``AssertionError`` (the exception the caller
    handles) instead of an error raised from inside ``torch.cuda.set_device``.

    Returns:
        torch.device: The ``cuda:1`` device.

    Raises:
        AssertionError: If CUDA is unavailable, the number of visible GPUs is
            not exactly two, device 0 is not named 'GeForce RTX 2080 Ti', or
            device 1 did not become the current device.
    """
    device_id = 1

    def _require(condition, message):
        # Explicit raise (rather than `assert`) so the checks still run when
        # Python is started with -O.
        if not condition:
            raise AssertionError(message)

    # Sanity checks
    _require(torch.cuda.is_available() == True, 'GPU not available')
    _require(torch.cuda.device_count() == 2, 'Cannot find both GPUs')
    _require(torch.cuda.get_device_name(0) == 'GeForce RTX 2080 Ti', 'Wrong GPU name')

    torch.cuda.set_device(device_id)
    _require(torch.cuda.current_device() == device_id, 'Using wrong GPU')
    return torch.device('cuda', device_id)

In [5]:
# Fix every RNG seed before any model or data code runs.
seed = 42069
seed_everything(seed)

# Prefer the dedicated GPU and fall back to the CPU when it cannot be set up.
# torch.cuda reports a missing or invalid device with RuntimeError rather than
# AssertionError, so both are treated as "GPU unavailable" instead of letting
# the cell crash.
try:
    device = pytorch_init_janus_gpu()
except (AssertionError, RuntimeError) as e:
    print('GPU could not initialize, got error:', e)
    device = torch.device('cpu')

print('Using device:', device)

Using device: cuda:1


In [6]:
# Show the kernel's working directory; the relative paths used in this notebook
# ('./detr-fish/', 'torch/hub/...') are resolved against it.
!pwd

/app


In [7]:
# Point torch.hub's module-level default cache directory at the current directory.
# NOTE(review): presumably this makes the hub cache resolve under ./torch/hub, matching
# the relative repo path handed to torch.hub.load in this notebook — confirm.
# NOTE(review): torch.hub.set_dir() looks like the documented way to relocate the hub
# directory; consider it if this constant stops being honoured — confirm.
torch.hub.DEFAULT_CACHE_DIR = '.'

In [10]:
class DETRModel(nn.Module):
    """Pretrained DETR (ResNet-50) from a local torch.hub checkout with a re-sized head.

    The classification layer is replaced by a freshly initialised linear layer
    with ``num_classes`` outputs, and the object-query table is re-sized when
    ``num_queries`` differs from the pretrained model.

    Args:
        num_classes: Output width of the new classification layer.
            NOTE(review): upstream DETR appears to size this layer as
            "classes + 1" (an extra "no object" logit) — confirm the value
            passed here already accounts for that.
        num_queries: Number of object queries, i.e. predictions per image.
    """

    def __init__(self,num_classes,num_queries):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries

        self.model = torch.hub.load(
            'torch/hub/facebookresearch_detr_master',
            'detr_resnet50',
            source='local',
            pretrained=True,
        )
        self.in_features = self.model.class_embed.in_features

        # Swap the pretrained classification head for one of the requested width.
        self.model.class_embed = nn.Linear(in_features=self.in_features,out_features=self.num_classes)

        # Assigning `num_queries` alone does not change how many predictions the
        # network emits: that is determined by the rows of the query embedding.
        # Rebuild the table only when the size differs, so the pretrained
        # queries are kept whenever the requested count already matches.
        pretrained_queries = self.model.query_embed
        if pretrained_queries.num_embeddings != self.num_queries:
            self.model.query_embed = nn.Embedding(
                self.num_queries, pretrained_queries.embedding_dim
            )
        self.model.num_queries = self.num_queries

    def forward(self,images):
        """Run the wrapped DETR model on ``images`` and return its output unchanged."""
        return self.model(images)

In [11]:
# Wrap the pretrained DETR with a 10-output classification head and 100 object queries.
model = DETRModel(num_classes=10, num_queries=100)

In [20]:
# Load the stock pretrained DETR directly from the local hub checkout.
# NOTE(review): this rebinds `model`, discarding the DETRModel wrapper assigned
# to the same name earlier in the notebook — confirm that is intended.
model = torch.hub.load(
    'torch/hub/facebookresearch_detr_master',
    'detr_resnet50',
    source='local',
    pretrained=True,
)

In [32]:
# Dummy batch for a forward-pass smoke test: 2 images, 3 channels, 512x512,
# holding integer-valued floats drawn from [0, 255) (the upper bound is exclusive).
samples = torch.randint(low=0, high=255, size=(2, 3, 512, 512), dtype=torch.float)

In [33]:
# Forward-pass smoke test. Gradients are not needed just to inspect the
# predictions, so skip autograd tracking instead of keeping the whole graph
# alive through the notebook's Out[...] history.
# NOTE(review): the model is not switched to eval mode here; if it contains
# dropout layers the printed values are stochastic — confirm whether
# model.eval() is wanted before this call.
with torch.no_grad():
    outputs = model(samples)
outputs

{'pred_logits': tensor([[[-16.6211,  -2.0212, -11.4596,  ...,  -8.6111,  -4.0576,  10.3388],
          [-17.0190,  -2.3417, -10.0571,  ...,  -8.5198,  -5.5366,  10.6736],
          [-17.1467,  -3.0588,  -9.2692,  ...,  -7.5721,  -4.7911,   9.6887],
          ...,
          [-17.5096,  -3.7635, -11.3821,  ..., -10.0201,  -7.1388,  10.8480],
          [-17.4604,  -3.4431, -11.2893,  ...,  -8.4127,  -6.1204,  10.4343],
          [-17.8071,  -3.6522, -12.2925,  ...,  -7.1949,  -5.0404,  10.7811]],
 
         [[-16.7074,  -3.5336, -10.2962,  ...,  -8.4931,  -6.2095,   9.9806],
          [-18.1759,  -3.2416, -13.1338,  ...,  -7.0044,  -5.0956,  11.7741],
          [-16.5708,  -4.2240,  -9.6216,  ...,  -9.7981,  -4.6481,  10.0605],
          ...,
          [-16.7176,  -4.8037, -10.3876,  ...,  -9.5806,  -6.6510,   9.8639],
          [-17.6222,  -5.3594, -12.2140,  ..., -10.5969,  -7.1044,  10.4724],
          [-17.1385,  -4.5561, -11.8686,  ...,  -8.1604,  -6.1928,  10.7749]]],
        grad_f