<a href="https://colab.research.google.com/github/Alenush/dish_id_sirius/blob/Team-1/chefnet_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Загрузка данных

In [1]:
!pip install gdown



### Загрузка картинок

In [2]:
!gdown https://drive.google.com/uc?id=1rN0yvtlHkDbhHWRjp104FQuCyl9BAZuC
!unzip -qq AllRecipes_images.zip 
!rm AllRecipes_images.zip

Downloading...
From: https://drive.google.com/uc?id=1rN0yvtlHkDbhHWRjp104FQuCyl9BAZuC
To: /content/AllRecipes_images.zip
3.70GB [01:26, 42.9MB/s]


In [3]:
import os
print(f"Всего картинок - {len(os.listdir('./AllRecipes_images'))}")

Всего картинок - 187047


### Загрузка ингредиентов блюд

In [4]:
!gdown https://drive.google.com/uc?id=1rNb_CqMtA0lx-JzOxWhyM204SGQ0s5qJ
!mkdir AllRecipes_ingred
!unzip -qq db.zip -d AllRecipes_ingred
!rm db.zip

Downloading...
From: https://drive.google.com/uc?id=1rNb_CqMtA0lx-JzOxWhyM204SGQ0s5qJ
To: /content/db.zip
13.5MB [00:00, 22.0MB/s]


### MongoDB

In [5]:
!python -m pip install pymongo==3.7.2
!apt install mongodb

Collecting pymongo==3.7.2
[?25l  Downloading https://files.pythonhosted.org/packages/b1/45/5440555b901a8416196fbf2499c4678ef74de8080c007104107a8cfdda20/pymongo-3.7.2-cp36-cp36m-manylinux1_x86_64.whl (408kB)
[K     |████████████████████████████████| 409kB 2.8MB/s 
[?25hInstalling collected packages: pymongo
  Found existing installation: pymongo 3.10.1
    Uninstalling pymongo-3.10.1:
      Successfully uninstalled pymongo-3.10.1
Successfully installed pymongo-3.7.2
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-440
Use 'apt autoremove' to remove it.
The following additional packages will be installed:
  libpcap0.8 libstemmer0d libyaml-cpp0.5v5 mongo-tools mongodb-clients
  mongodb-server mongodb-server-core
The following NEW packages will be installed:
  libpcap0.8 libstemmer0d libyaml-cpp0.5v5 mongo-tools mongodb mongodb-clients
  mongodb-

In [6]:
!mongod --dbpath AllRecipes_ingred/db --fork --logpath /var/log/mongod.log

about to fork child process, waiting until server is ready for connections.
forked process: 641
child process started successfully, parent exiting


In [7]:
from pymongo import MongoClient
from bson import json_util
import json

In [8]:
db_client = MongoClient('localhost', 27017)
recipe_db = db_client['allrecipes']['recipe_data']
recipe_list = list(recipe_db.find())
print(f"Всего блюд - {len(recipe_list)}")

for recipe in recipe_list:
  if '_id' in recipe:
    del recipe['_id']

with open('recipes_raw.json', 'w') as f:
  f.write(json.dumps(recipe_list, indent=2))

Всего блюд - 31233


## Preprocessing

In [9]:
from PIL import Image
from tqdm.notebook import tqdm
import pickle

Проверка на битые картинки

In [10]:
path_to_img = 'AllRecipes_images/'
img_size = (250, 250)

bad_imgs_count = 0
for file_name in tqdm(os.listdir(path_to_img), desc='Checking images'):
  img = Image.open(path_to_img + file_name).convert('RGB')
  if img.size != img_size:
    print(f'Image {file_name} size - {img.size}')
    bad_imgs_count += 1

print(f'Total bad images found = {bad_imgs_count}')

HBox(children=(FloatProgress(value=0.0, description='Checking images', max=187047.0, style=ProgressStyle(descr…


Total bad images found = 0


Удаление дубликатов рецептов

In [11]:
with open('recipes_raw.json', 'rb') as f:
  recipe_list = json.loads(f.read())

In [12]:
unique_ids = set()

tmp = []
for recipe in recipe_list:
  if recipe['id'] not in unique_ids:
    tmp.append(recipe)
    unique_ids.add(recipe['id'])

recipe_list = tmp
print((f'Рецептов после удаления дубликатов - {len(recipe_list)}'))

Рецептов после удаления дубликатов - 27827


In [13]:
with open('recipes_raw.json', 'w') as f:
  f.write(json.dumps(recipe_list))

In [14]:
empty_recipes_count = 0 
for recipe in recipe_list:
  if len(recipe['ingred_list']) == 0:
    empty_recipes_count += 1

print(f'Рецептов с пустым списком ингред-ов - {empty_recipes_count}')

Рецептов с пустым списком ингред-ов - 0


Препроцессинг ингредиентов

In [15]:
import html
from html.parser import HTMLParser
import re
import unicodedata

REPLACEMENTS = {
    u'\x91':"'", u'\x92':"'", u'\x93':'"', u'\x94':'"', u'\xa9':'',
    u'\xba': ' degrees ', u'\xbc':' 1/4', u'\xbd':' 1/2', u'\xbe':' 3/4',
    u'\xd7':'x', u'\xae': '',
    '\\u00bd':' 1/2', '\\u00bc':' 1/4', '\\u00be':' 3/4',
    u'\\u2153':' 1/3', '\\u00bd':' 1/2', '\\u00bc':' 1/4', '\\u00be':' 3/4',
    '\\u2154':' 2/3', '\\u215b':' 1/8', '\\u215c':' 3/8', '\\u215d':' 5/8',
    '\\u215e':' 7/8', '\\u2155':' 1/5', '\\u2156':' 2/5', '\\u2157':' 3/5',
    '\\u2158':' 4/5', '\\u2159':' 1/6', '\\u215a':' 5/6', '\\u2014':'-',
    '\\u0131':'1', '\\u2122':'', '\\u2019':"'", '\\u2013':'-', '\\u2044':'/',
    '\\u201c':'\\"', '\\u2018':"'", '\\u201d':'\\"', '\\u2033': '\\"',
    '\\u2026': '...', '\\u2022': '', '\\u2028': ' ', '\\u02da': ' degrees ',
    '\\uf04a': '', u'\xb0': ' degrees ', '\\u0301': '', '\\u2070': ' degrees ',
    '\\u0302': '', '\\uf0b0': ''
}

parser = HTMLParser()
def prepro_txt(text):
    import urllib

    text = html.unescape(text)

    for unichar, replacement in REPLACEMENTS.items():
      if unichar in text:
        text = text.replace(unichar, replacement)
    text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore')

    try:
        text = urllib.parse.unquote(text.decode('ascii'))
    except UnicodeDecodeError:
        pass # if there's an errant %, unquoting will yield an invalid char

    # some extra tokenization
    text = ' - '.join(text.split('-'))
    text = ' & '.join(text.split('&'))

    text = re.sub(r'\\[nt]', ' ', text) # remove over-escaped line breaks and tabs
    text = re.sub(r'\b([^\d\s]+)/(.*)\b', r'\1 / \2', text) # split non-fractions
    text = re.sub(r'\b(.*)/([^\d\s]+)\b', r'\1 / \2', text) # e.g. 350 deg/gas mark
    text = re.sub(r'\s+', ' ', text) # remove extra whitespace

    return text.strip()

In [16]:
for recipe in tqdm(recipe_list, desc='Preprocess ingredients'):
  ingred_list = recipe['ingred_list']
  new_ingred_list = []
  for ingred in ingred_list:
    new_ingred_list.append(prepro_txt(ingred))
  recipe['ingred_list'] = new_ingred_list

HBox(children=(FloatProgress(value=0.0, description='Preprocess ingredients', max=27827.0, style=ProgressStyle…




In [17]:
from string import ascii_lowercase
def replace_units(s):
  tokens = s.replace('to taste', '').split()
  units = ['ounce', 'ounces', 'cups', 'cup', 'teaspoon', 'tablespoon', 'tablespoons', 'teaspoons', 
           'c', 'g', 'v', 'tbsp', 'x', 'ml', 'lb', 'tbs', 'oz', 'pkg', 'large', 'small', 'tsp', 'inch', 
           'grams', 'quarts', 'lbs', 'can', 'cube', 'whole', 'or', 'pieces', 'piece', 'chopped', 
           'shredded', 'diced', 'fresh', 'crushed', 'minced']
  tokens_new = []
  for t in tokens:
      if t not in units:
          tokens_new.append(t)
  
  return ' '.join(tokens_new)

def cleanup_ingredient_list(l):
    l = [replace_units(''.join([char for char in x.lower() if char in ascii_lowercase + ' ']).strip()) for x in l]
    return l

In [18]:
for recipe in tqdm(recipe_list, desc='Preprocess ingredients'):
  recipe['ingred_list'] = cleanup_ingredient_list(recipe['ingred_list'])

HBox(children=(FloatProgress(value=0.0, description='Preprocess ingredients', max=27827.0, style=ProgressStyle…




In [28]:
import csv
from nltk import ngrams
from collections import Counter

In [23]:
import string
clean_ingredients = set()

with open('food.csv', 'r') as f:
  reader = csv.reader(f)
  next(reader)
  for row in tqdm(reader):
    ingr = row[2].lower()
    ingr = re.sub(r',', '', ingr)
    ingr_split = [word for word in ingr.split() if word.isalpha()]
    if len(ingr_split) > 5:
      ingr_split = ingr_split[:5]
    bigrams = []
    for i, ingr_tuple in enumerate(list(iter(ngrams(ingr_split, 2)))):
      bigrams.append(' '.join(ingr_tuple))
    ingr_split.reverse()
    for ingr_tuple in list(iter(ngrams(ingr_split, 2))):
      bigrams.append(' '.join(ingr_tuple))
    unigrams = ingr_split
    candidates = [*bigrams, *unigrams]
    
    clean_ingredients.update(candidates)

clean_ingredients = list(clean_ingredients)
print(len(clean_ingredients))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


3713


In [24]:
from nltk import ngrams

def select_n_max_length_ingr(l, n=2):
    l.sort(key=lambda x: -len(x))
    return l[:n]

def select_all(l):
    return l

def select_percentage_and_min(l, percentage=0.65, min_count=2):
    return select_n_max_length_ingr(l, n=max(min_count, round(len(l)*percentage)))

def clean_ingredient_list(l, clean_ingredients):
    recipe = []

    for ingr in l:
      bigrams = []
      for i, ingr_tuple in enumerate(list(iter(ngrams(ingr.split(), 2)))):
        bigrams.append(' '.join(ingr_tuple))
      bigr_found = False
      for ingr in bigrams:
        if ingr in clean_ingredients:
          recipe.append(ingr)
          bigr_found = True
          break
      if bigr_found:
        continue
      unigrams = ingr.split()
      for ingr in unigrams:
        if ingr in clean_ingredients:
          recipe.append(ingr)
          break

    return recipe

In [25]:
# clean_ingredients.remove('all purpose')
# clean_ingredients.remove('skinless boneless')
# clean_ingredients.remove('boneless')
# clean_ingredients.remove('skinless')
# clean_ingredients.remove('prepared')

In [26]:
for recipe in tqdm(recipe_list):
  recipe['ingred_list'] = clean_ingredient_list(recipe['ingred_list'], clean_ingredients)

HBox(children=(FloatProgress(value=0.0, max=27827.0), HTML(value='')))






In [None]:
find_n_ingredients = 3000
select_top_n_ingredients = 1000

ingredient_preselection = []
unique_ingredient_count = 0

for recipe in tqdm(recipe_list):
    new_ingred_list = recipe['ingred_list']
    ingredient_preselection.extend(new_ingred_list)
        
    unique_ingredient_count = len(set(ingredient_preselection))
    if unique_ingredient_count > find_n_ingredients:
        break

print(f'Found {unique_ingredient_count} unique ingredients') 
        
c = Counter(ingredient_preselection)
clean_ingredients = list([i for i, n in c.most_common()][:select_top_n_ingredients])

In [59]:
recipe_list = [recipe for recipe in recipe_list if len(recipe['ingred_list']) > 0]

In [60]:
with open('recipes.json', 'w') as f:
  f.write(json.dumps(recipe_list, indent=2))

## Обучение

In [49]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from collections import defaultdict
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [33]:
id2word = defaultdict()
id2word[0] = '<pad>'
id2word[1] = '<start>'
id2word[2] = '<end>'
id2word[3] = '<unk>'
for i, ingr in enumerate(clean_ingredients):
  id2word[i+4] = ingr

word2id = {v:k for k,v in id2word.items()}

In [34]:
def ids2words(word_ids):
  return [id2word[id] for id in word_ids]

def words2ids(words):
  return [word2id[word] for word in words]

In [79]:
class RecipesDataset(Dataset):
  def __init__(self, recipes_json, img_folder, transform=None):
    super().__init__()

    with open(recipes_json, 'r') as f:
      recipes = json.load(f)
    self.recipes = defaultdict(dict)
    for recipe in recipes:
      self.recipes[recipe['id']] = recipe

    self.img_folder = img_folder
    self.imgs = os.listdir(self.img_folder)
    self.imgs = [img for img in self.imgs if img.split('_')[0] in self.recipes]
    
    self.transform = transform

  def __getitem__(self, index):
    img_name = self.imgs[index]
    img = Image.open(self.img_folder + img_name).convert('RGB')
    if self.transform is not None:
      img = self.transform(img)
    
    recipe_id = img_name.split('_')[0]
    ingred_list = self.recipes[recipe_id]['ingred_list']
    return img, torch.LongTensor(words2ids(ingred_list))

  def __len__(self):
    return len(self.imgs)

In [36]:
from torch.nn.utils.rnn import pad_sequence

def pad_collate(data):
  # отсортируем по длине описания
  data.sort(key=lambda x: len(x[1]), reverse=True)
  images, captions = zip(*data)
  
  # составим 4D тензор изображений из кортежа 3D тензоров
  # images: (batch_size, channels(rgb), width, height)
  images = torch.stack(images, 0)

  # составим 2D тензор описаний из кортежа 1D тензоров
  # дополним каждое описание символом <pad> так, чтобы у всех описаний совпадали длины
  lengths = [len(cap) for cap in captions]
  targets = torch.LongTensor(np.zeros((len(captions), max(lengths))))
  for i, cap in enumerate(captions):
    end = lengths[i]
    targets[i, :end] = cap[:end]  

  return images, targets, lengths

In [80]:
transform = transforms.Compose([
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

train_data = RecipesDataset('recipes.json', 'AllRecipes_images/', transform)

batch_size = 256
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                          num_workers=4, collate_fn=pad_collate, shuffle=True)

In [38]:
class Encoder(nn.Module):
  def __init__(self, encoder):
    super().__init__()

    for param in encoder.parameters():
        param.requires_grad = False

    modules = list(encoder.children())[:-1]
    self.encoder = nn.Sequential(*modules)
  
  def forward(self, images):
    with torch.no_grad():
      features = self.encoder(images)
    features = features.view(features.size(0), -1)
    return features

In [39]:
encoder = Encoder(models.resnet50(pretrained=True)).to(device)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




In [128]:
class Decoder(nn.Module):
    def __init__(self, encod_size, embed_size, hidden_size, vocab_size, num_layers=2, dropout=0):
      super().__init__()
      self.encod_feat = nn.Linear(encod_size, embed_size)
      #self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
      self.embedding = nn.Embedding(vocab_size, embed_size)
      self.rnn = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True,
                         dropout=dropout, bidirectional=True)
      self.linear = nn.Linear(hidden_size * 2, vocab_size)
  
    # обучение, используются истинные описания
    def forward(self, features, captions, lengths):
      embed = self.embedding(captions)
      features = self.encod_feat(features)
      #featues = self.bn(features)
      embed = torch.cat((features.unsqueeze(1), embed), 1)
      packed = pack_padded_sequence(embed, lengths, batch_first=True)
      outputs, _ = self.rnn(packed)
      outputs = self.linear(outputs[0])
      return outputs

    # проверка, истинные описания не используются
    def sample(self, features, max_len, states=None):
      id_preds = list()
      features = self.encod_feat(features)
      #features = self.bn(features)
      inputs = features.unsqueeze(1)
      for i in range(max_len):
        rnn_outputs, states = self.rnn(inputs, states)
        lin_outputs = self.linear(rnn_outputs.squeeze(1))
        _, pred_id = lin_outputs.max(1)
        id_preds.append(pred_id)
        inputs = self.embedding(pred_id).unsqueeze(1)
      id_preds = torch.stack(id_preds, 1)
      return id_preds

In [129]:
encod_size = 2048
embed_size = 512
hidden_size = 512
vocab_size = len(word2id)

decoder = Decoder(encod_size, embed_size, hidden_size, vocab_size).to(device)

In [136]:
def train(encoder, decoder, criterion, optimizer, num_epochs,
          epoch_losses, log_step=100):

  batch_losses = defaultdict(list)
  for epoch in range(1, num_epochs + 1):
    # режим обучения

    print(f'Эпоха {epoch}')

    running_loss = 0.0
    
    i = 1
    for images, tokens, lengths in tqdm(train_loader, desc=f'Обучение'):
      # Перенесем изобр-я и описания на GPU, составим вектор истинных описаний
      images = images.to(device)
      tokens = tokens.to(device)
      targets = pack_padded_sequence(tokens, lengths, batch_first=True)[0]

      # Проведем изобр-я через модель, получим описания (в виде векторов id), сгенерир. моделью
      features = encoder(images)
      outputs = decoder(features, tokens[:, :-1], lengths)
      
      # посчитать лосс на батче
      loss = criterion(outputs, targets)
      
      # посчитать градиенты
      loss.backward()
      
      # обновить параметры модели
      optimizer.step()
      
      # обнулить градиенты
      optimizer.zero_grad()

      running_loss += loss.item()

      if i % log_step == 0:
        batch_loss = loss.item()
        print(f'Batch {i} loss:\t{batch_loss:.4f}')
        batch_losses['train_nll'].append(batch_loss)

      i+=1

    torch.save(decoder.state_dict(), f'weights/{i}_epoch.pth')
    

    train_loss = running_loss / len(train_loader)
    print(f'Epoch train loss:\t{train_loss:.4f}')
    epoch_losses['train_nll'].append(train_loss)
    
  return epoch_losses

In [None]:
for p in decoder.rnn.parameters():
    p.register_hook(lambda grad: torch.clamp(grad, -5, 5))

criterion = nn.CrossEntropyLoss()
torch.manual_seed(42)
optimizer = optim.Adam(params=decoder.parameters(), lr=1e-3)
epoch_losses = defaultdict(list)

encoder.train()
decoder.train()

torch.manual_seed(42)
train(encoder, decoder, criterion, optimizer, num_epochs=5, epoch_losses=epoch_losses)

Эпоха 1


HBox(children=(FloatProgress(value=0.0, description='Обучение', max=728.0, style=ProgressStyle(description_wid…

Batch 100 loss:	0.0562
Batch 200 loss:	0.0630
Batch 300 loss:	0.0635
Batch 400 loss:	0.0482
Batch 500 loss:	0.0612
Batch 600 loss:	0.0444


In [117]:
torch.save(decoder.state_dict(), f'weights/5_epochs.pth')

In [None]:
plt.figure(figsize=(12, 8))
plt.plot(epoch_losses['train_nll'])
plt.title('Лосс на обучении')
plt.show()

Посмотрим на результат

In [130]:
decoder.load_state_dict(torch.load(f'weights/5_epochs.pth'))

<All keys matched successfully>

In [None]:
from random import randint

decoder.eval()
val_images = os.listdir('AllRecipes_images')
with torch.no_grad():
  img = Image.open('AllRecipes_images/' + val_images[randint(0, len(val_images)-1)])
  plt.figure(figsize=(12,8))
  plt.imshow(img)
  img = transform(img).to(device).unsqueeze(0)
  ingred_pred = ids2words(decoder.sample(encoder(img), 7)[0].cpu().numpy())
  print(ingred_pred)