In [4]:
import time,os,json
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torch.utils.data import Dataset

In [2]:
from Models.models import *
import wandb
import torch
import torch.nn 
import os 

import fasttext
import fasttext.util

# Default hyper-parameters for a run.
# NOTE(review): `make` reads these with attribute access (`config.batch_size`,
# `config.learning_rate`), which a plain dict does not support -- presumably
# this dict is wrapped (e.g. by wandb) before being passed on; confirm.
config = {
    "epochs": 5,
    "classes": 10,
    "kernels": [16, 32],
    "batch_size": 128,
    "learning_rate": 5e-3,
    "dataset": "MNIST",
}

class ConTextDataset(Dataset):
    """Image + OCR-text dataset read from split list files.

    Each item is a tuple ``(image, text, target)``:

    * ``image``  -- the RGB image ``<images_directory>/<name>.jpg``, passed
      through ``transform`` when one was given.
    * ``text``   -- numpy array of shape ``(max_num_words, dim_fasttext)``
      with one fastText vector per distinct OCR word, zero-padded.
    * ``target`` -- scalar ``torch`` tensor holding the integer label.
    """

    # fastText model shared by all instances, so that building both a train
    # and a test dataset loads 'cc.en.300.bin' only once per process.
    _shared_fasttext = None

    def __init__(self, path_test_train_files, images_directory, train = True, transform = False):
        """Read the split files and attach the fastText model.

        Args:
            path_test_train_files: directory holding 'train.txt' / 'test.txt'
                and the matching 'ocr_train.txt' / 'ocr_test.txt'.
            images_directory: directory holding the '.jpg' images.
            train: use the train files when true, the test files otherwise.
            transform: callable applied to each loaded image; any falsy
                value disables it.
        """
        self.images_directory      = images_directory
        self.path_test_train_files = path_test_train_files
        self.transform             = transform

        split = 'train' if train else 'test'
        path_    = os.path.join(path_test_train_files, split + '.txt')           # image names and labels
        path_ocr = os.path.join(path_test_train_files, 'ocr_' + split + '.txt')  # OCR text, one line per image

        with open(path_, 'r') as file, open(path_ocr, 'r') as ocr_File:
            # One tuple per line of the split file: (image name, label).
            self.samples = [tuple(line.split()) for line in file]
            # One string per line of the OCR file, trailing whitespace removed.
            self.text    = [text.rstrip() for text in ocr_File]

        self.fasttext = self._load_fasttext()
        self.dim_fasttext = self.fasttext.get_dimension()
        self.max_num_words = 64

    @staticmethod
    def _load_fasttext():
        """Return the shared fastText model, fetching and loading it on first use."""
        if ConTextDataset._shared_fasttext is None:
            fasttext.util.download_model('en', if_exists='ignore')  # English
            ConTextDataset._shared_fasttext = fasttext.load_model('cc.en.300.bin')
        return ConTextDataset._shared_fasttext

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.samples)

    def _embed_text(self, idx):
        """Return the zero-padded fastText matrix for the OCR text of sample ``idx``.

        Words of one or two characters are dropped and duplicates are removed
        while keeping first-occurrence order, so both the row order and the
        words kept when truncating to ``max_num_words`` are reproducible
        (a plain ``set`` would order them by randomised string hashes).
        """
        text = np.zeros((self.max_num_words, self.dim_fasttext))
        ocr_line = self.text[idx]
        # A line that is exactly '0' is skipped -- presumably the marker for
        # "no text detected"; confirm against the OCR file format.
        if ocr_line == '0':
            return text

        unique_words = list(dict.fromkeys(w for w in ocr_line.split() if len(w) > 2))
        for i, w in enumerate(unique_words[:self.max_num_words]):
            text[i, :] = self.fasttext.get_word_vector(w)
        return text

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, text, target)``."""
        img_name = os.path.join(self.images_directory, self.samples[idx][0]+'.jpg')
        # NOTE(review): `Image` is not imported explicitly in this file;
        # presumably it arrives through `from Models.models import *`.
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(img_name) as img_file:
            image = img_file.convert('RGB')

        if self.transform:
            image = self.transform(image)

        text = self._embed_text(idx)
        target = torch.tensor(int(self.samples[idx][1]))

        return image, text, target

def get_transform(train = True):
    """Build the torchvision preprocessing pipeline for one split.

    With ``train`` true the spatial stage is a random resized crop to 256
    followed by a random horizontal flip; otherwise it is a resize to 256
    followed by a 256 centre crop. Both pipelines then convert to a tensor
    and normalise with the same per-channel mean / std.
    """
    input_size = 256
    tv = torchvision.transforms

    if train:
        spatial = [tv.RandomResizedCrop(input_size), tv.RandomHorizontalFlip()]
    else:
        spatial = [tv.Resize(input_size), tv.CenterCrop(input_size)]

    # Tensor conversion and normalisation are the same for both splits.
    common_tail = [
        tv.ToTensor(),
        tv.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return tv.Compose(spatial + common_tail)


def make_loader(dataset, batch_size):
    """Wrap ``dataset`` in a shuffling ``DataLoader`` with the given batch size.

    The loader always shuffles, pins memory and uses two worker processes.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "pin_memory": True,
        "num_workers": 2,
    }
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)

def make(config, device="cuda"):
    """Build the data loaders, model, loss and optimizer for a run.

    Args:
        config: object exposing ``batch_size`` and ``learning_rate`` as
            attributes.
        device: NOTE(review): currently unused -- the model is created but
            never moved to this device here; confirm whether the caller
            does that.

    Returns:
        ``(model, train_loader, test_loader, criterion, optimizer)``.
    """
    # Locations of the split / OCR files and of the images.
    data_root = '/home/xnmaster/data/'
    jpeg_dir  = '/home/xnmaster/data/JPEGImages/'

    # Train dataset is built first, then test, each with its own transform.
    train_set, test_set = (
        ConTextDataset(data_root, jpeg_dir, train = flag, transform = get_transform(train = flag))
        for flag in (True, False)
    )

    train_loader, test_loader = (
        make_loader(split, batch_size=config.batch_size)
        for split in (train_set, test_set)
    )

    # Model (left on its default device; see the note on `device`).
    model = ConTextTransformer()

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    return model, train_loader, test_loader, criterion, optimizer

ModuleNotFoundError: No module named 'fasttext'