In [51]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from data_loader import get_loader
from models import VqaModel


device = torch.device('cpu')

In [None]:
def infer(model_fname):
    
    data_loader = get_loader(
        input_dir='/home/lilydpjja/basic_vqa/demo/',
        input_vqa_train='demo.npy',
        input_vqa_valid='demo.npy',
        max_qst_length=30,
        max_num_ans=10,
        batch_size=1,
        num_workers=8)
    
    qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
    ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
    ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx
    
    model = VqaModel(
        embed_size=1024,
        qst_vocab_size=qst_vocab_size,
        ans_vocab_size=ans_vocab_size,
        word_embed_size=300,
        num_layers=2,
        hidden_size=512).to(device)
    
    model = torch.load(os.path.join('/home/lilydpjja/basic_vqa/models/', model_fname))
    model.eval()
    
    output = model(image, question)
    probs, indices = torch.sort(F.softmax(output.squeeze(), dim=0), dim=0, descending=True)
    probs_top5 = probs.tolist()[:5]
    answers_top5 = [idx2ans[idx] for idx in indices.tolist()[:5]]
    
    return probs_top5, answers_top5

In [None]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from data_loader import get_loader
from models import VqaModel


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(args):

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    
    

    qst_vocab_size = dataset.qst_vocab.vocab_size
    ans_vocab_size = dataset.ans_vocab.vocab_size
    ans_unk_idx = dataset.ans_vocab.unk2idx

    model = VqaModel(
        embed_size=args.embed_size,
        qst_vocab_size=qst_vocab_size,
        ans_vocab_size=ans_vocab_size,
        word_embed_size=args.word_embed_size,
        num_layers=args.num_layers,
        hidden_size=args.hidden_size).to(device)

    criterion = nn.CrossEntropyLoss()

    params = list(model.img_encoder.fc.parameters()) \
        + list(model.qst_encoder.parameters()) \
        + list(model.fc1.parameters()) \
        + list(model.fc2.parameters())

    optimizer = optim.Adam(params, lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    for epoch in range(args.num_epochs):

        for phase in ['train', 'valid']:

            running_loss = 0.0
            running_corr_exp1 = 0
            running_corr_exp2 = 0
            batch_step_size = len(data_loader[phase].dataset) / args.batch_size

            if phase == 'train':
                scheduler.step()
                model.train()
            else:
                model.eval()

            for batch_idx, batch_sample in enumerate(data_loader[phase]):

                image = batch_sample['image'].to(device)
                question = batch_sample['question'].to(device)
                label = batch_sample['answer_label'].to(device)
                multi_choice = batch_sample['answer_multi_choice']  # not tensor, list.

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):

                    output = model(image, question)      # [batch_size, ans_vocab_size=1000]
                    _, pred_exp1 = torch.max(output, 1)  # [batch_size]
                    _, pred_exp2 = torch.max(output, 1)  # [batch_size]
                    loss = criterion(output, label)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Evaluation metric of 'multiple choice'
                # Exp1: our model prediction to '<unk>' IS accepted as the answer.
                # Exp2: our model prediction to '<unk>' is NOT accepted as the answer.
                pred_exp2[pred_exp2 == ans_unk_idx] = -9999
                running_loss += loss.item()
                running_corr_exp1 += torch.stack([(ans == pred_exp1.cpu()) for ans in multi_choice]).any(dim=0).sum()
                running_corr_exp2 += torch.stack([(ans == pred_exp2.cpu()) for ans in multi_choice]).any(dim=0).sum()

                # Print the average loss in a mini-batch.
                if batch_idx % 100 == 0:
                    print('| {} SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f}'
                          .format(phase.upper(), epoch+1, args.num_epochs, batch_idx, int(batch_step_size), loss.item()))

            # Print the average loss and accuracy in an epoch.
            epoch_loss = running_loss / batch_step_size
            epoch_acc_exp1 = running_corr_exp1.double() / len(data_loader[phase].dataset)      # multiple choice
            epoch_acc_exp2 = running_corr_exp2.double() / len(data_loader[phase].dataset)      # multiple choice

            print('| {} SET | Epoch [{:02d}/{:02d}], Loss: {:.4f}, Acc(Exp1): {:.4f}, Acc(Exp2): {:.4f} \n'
                  .format(phase.upper(), epoch+1, args.num_epochs, epoch_loss, epoch_acc_exp1, epoch_acc_exp2))

            # Log the loss and accuracy in an epoch.
            with open(os.path.join(args.log_dir, '{}-log-epoch-{:02}.txt')
                      .format(phase, epoch+1), 'w') as f:
                f.write(str(epoch+1) + '\t'
                        + str(epoch_loss) + '\t'
                        + str(epoch_acc_exp1.item()) + '\t'
                        + str(epoch_acc_exp2.item()))

        # Save the model check points.
        if (epoch+1) % args.save_step == 0:
            torch.save({'epoch': epoch+1, 'state_dict': model.state_dict()},
                       os.path.join(args.model_dir, 'model-epoch-{:02d}.ckpt'.format(epoch+1)))


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument('--input_dir', type=str, default='./datasets',
                        help='input directory for visual question answering.')

    parser.add_argument('--log_dir', type=str, default='./logs',
                        help='directory for logs.')

    parser.add_argument('--model_dir', type=str, default='./models',
                        help='directory for saved models.')

    parser.add_argument('--max_qst_length', type=int, default=30,
                        help='maximum length of question. \
                              the length in the VQA dataset = 26.')

    parser.add_argument('--max_num_ans', type=int, default=10,
                        help='maximum number of answers.')

    parser.add_argument('--embed_size', type=int, default=1024,
                        help='embedding size of feature vector \
                              for both image and question.')

    parser.add_argument('--word_embed_size', type=int, default=300,
                        help='embedding size of word \
                              used for the input in the LSTM.')

    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers of the RNN(LSTM).')

    parser.add_argument('--hidden_size', type=int, default=512,
                        help='hidden_size in the LSTM.')

    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='learning rate for training.')

    parser.add_argument('--step_size', type=int, default=10,
                        help='period of learning rate decay.')

    parser.add_argument('--gamma', type=float, default=0.1,
                        help='multiplicative factor of learning rate decay.')

    parser.add_argument('--num_epochs', type=int, default=30,
                        help='number of epochs.')

    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size.')

    parser.add_argument('--num_workers', type=int, default=8,
                        help='number of processes working on cpu.')

    parser.add_argument('--save_step', type=int, default=1,
                        help='save step of model.')

    args = parser.parse_args()

    main(args)


In [3]:
import numpy as np

In [4]:
train = np.load('./datasets/train.npy')

In [6]:
type(train)

numpy.ndarray

In [10]:
train[0]['image_path']

'/home/lilydpjja/basic_vqa/datasets/Resized_Images/train2014/COCO_train2014_000000458752.jpg'

In [17]:
train[0]

{'image_name': 'COCO_train2014_000000458752',
 'image_path': '/home/lilydpjja/basic_vqa/datasets/Resized_Images/train2014/COCO_train2014_000000458752.jpg',
 'question_id': 458752000,
 'question_str': 'What is this photo taken looking through?',
 'question_tokens': ['what',
  'is',
  'this',
  'photo',
  'taken',
  'looking',
  'through',
  '?'],
 'all_answers': ['net',
  'net',
  'net',
  'netting',
  'net',
  'net',
  'mesh',
  'net',
  'net',
  'net'],
 'valid_answers': ['net', 'net', 'net', 'net', 'net', 'net', 'net', 'net']}

In [18]:
demo = np.load('./demo/demo.npy')

In [20]:
demo[0]

{'image_name': 'COCO_train2014_000000458752',
 'image_path': '/home/lilydpjja/basic_vqa/upload/last.jpg',
 'question_id': 458752000,
 'question_str': 'what is the animal?',
 'question_tokens': ['what', 'is', 'the', 'animal'],
 'all_answers': ['net',
  'net',
  'net',
  'netting',
  'net',
  'net',
  'mesh',
  'net',
  'net',
  'net'],
 'valid_answers': ['net', 'net', 'net', 'net', 'net', 'net', 'net', 'net']}

In [21]:
from data_loader import VqaDataset, get_loader

{'image': tensor([[[-1.9980, -1.9980, -2.0152,  ..., -1.3815, -1.3473, -1.3302],
          [-1.9980, -1.9980, -2.0152,  ..., -1.3473, -1.3302, -1.3130],
          [-1.9809, -1.9980, -2.0152,  ..., -1.5528, -1.5528, -1.5528],
          ...,
          [ 0.0741,  0.1083,  0.1083,  ..., -1.5357, -1.5014, -1.4843],
          [ 0.1254,  0.1426,  0.1254,  ..., -1.5357, -1.5185, -1.5014],
          [ 0.0569,  0.0741,  0.0741,  ..., -1.5528, -1.5357, -1.5014]],
 
         [[-1.9132, -1.9132, -1.9307,  ..., -1.2654, -1.2304, -1.2129],
          [-1.9132, -1.9132, -1.9307,  ..., -1.2304, -1.2129, -1.1954],
          [-1.8957, -1.9132, -1.9307,  ..., -1.4405, -1.4405, -1.4405],
          ...,
          [ 0.2402,  0.2227,  0.2227,  ..., -1.6856, -1.6506, -1.6331],
          [ 0.3102,  0.3277,  0.3102,  ..., -1.6856, -1.6681, -1.6506],
          [ 0.2402,  0.2577,  0.2577,  ..., -1.7031, -1.6856, -1.6506]],
 
         [[-1.6824, -1.6824, -1.6999,  ..., -1.1421, -1.1073, -1.0898],
          [-1.6824,

In [29]:
from torchvision import transforms

In [31]:
data_loader = get_loader(
        input_dir='/home/lilydpjja/basic_vqa/demo/',
        input_vqa_train='demo.npy',
        input_vqa_valid='demo.npy',
        max_qst_length=30,
        max_num_ans=10,
        batch_size=1,
        num_workers=8)

In [24]:
for batch_sample in data_loader:
    print(batch_sample)

train
valid


In [5]:
len(train)

443757

In [1]:
ans2idx = [line.strip() for line in open('./demo/vocab_questions.txt', 'r').readlines()]

In [3]:
for ans in ans2idx:
    print(ans)

<pad>
<unk>
!
!"
!"?
!."?
"
" -
"&"
"'
")
"+"
",
"?
"??
"@"
#
#?
$
%
&
'
''
'?
'^'
(
("
)
)?
*
+
+-
+?
,
, "
, '
, '...
,"
,"?
-
--
-----
.
."
."?
.'
.'?
.,
..
...
... (
..."
...?
.?
/
/?
0
00
000
01
02
026m
03
05
06
07
1
10
100
100k
100m
100th
101
106
107
11
1100
110v
1110
114
11th
12
120
1201
1208
12th
13
130th
1313
138
14
141
144
14th
15
150
1560
1567
16
1600
16318
164
16e
17
172
176
17th
18
1800
1800s
1802
182
1876
189
1899
18th
19
1900
192
1920
1920s
1929
1930
1932
1940
1950
1950s
1960
1960s
1967
1970
1970s
1974
1975
1980
1980s
1985
1989
1990
1990s
1993
1999
19th
1pm
1st
2
20
200
2000
2001
2002
2005
201
2010
2011
2012
2013
2014
2015
2016
2017
2019
2029
206
20oz
20s
20th
21
213
21444
215
21st
22
220
2200
220v
2218
222
2244
225
23
24
243
248
25
250
256
257
25th
26
268c
27
2700
271
279
28
29
29013
2ft
2nd
2pm
2s
3
30
300
3000
300m
30b
32
3200
32d
33
330
336
34
343
34333
35
35mm
36
360
36th
37
37306
38
381
39
3999999
39th
3d
3m
3pm
3rd
4
40
400
4000
40th
41
410
42
42nd
44
4445534
45
4

brewed
brewery
brewing
brewster
brick
bricked
bricklayer
bricklaying
bricks
brickwork
bridal
bride
brides
bridesmaid
bridesmaids
bridge
bridges
bridle
bridled
bridles
brief
briefcase
briefcases
briefs
brigade
bright
brighter
brightest
brightly
brightness
brights
brilliant
brim
brindle
bring
bringing
brings
brio
brisbane
briskly
bristle
bristled
bristles
bristol
britain
brite
british
brittle
bro
broach
broad
broadcast
broadcasted
broadcasting
broadcasts
broads
broadway
broccoli
broccolis
brochure
brochures
broiling
broke
broken
bronx
bronze
brooding
brook
brooklyn
brooks
broom
broomrape
brooms
bros
broth
brothel
brother
brothers
brought
brown
browned
brownie
brownies
browning
brownish
browns
brownville
browser
browsing
brrahhhh
bruce
bruise
bruised
bruises
bruising
brunch
brunet
brunette
brunettes
brush
brushed
brushes
brushing
brushy
brussel
brussels
bryant
bs
btu
bu
bubble
bubblegum
bubbles
buck
bucket
buckets
buckingham
buckle
buckled
buckles
buckskin
bucolic
bud
buddha
buddhism
budd

daisy
dakota
dali
dallas
dalmatian
dalmations
dam
damage
damaged
damages
dame
damp
damsel
dance
dancer
dancers
dances
dancing
dandelion
dandelions
dandruff
dane
danes
dang
danger
dangerous
dangerously
dangers
dangling
danielle
daniels
danish
danishes
danny
dapper
dapple
dappled
daredevil
daredevils
dark
darkened
darkening
darker
darkest
darkness
darrell
dart
darth
darts
dasani
dash
dashboard
dashed
dashes
data
date
dated
dates
dating
daughter
daughters
dave
david
davidson
davis
dawn
day
daybed
daybreak
daycare
daydreaming
daylight
daylilies
days
daytime
dayton
daytona
db
dc
ddr
de
dead
deadbolt
deadly
deaf
deal
dealer
dealers
dealership
dealing
deals
dean
death
deathly
deaths
debate
debbie
debit
deboarding
debonair
debris
decade
decades
decaf
decaffeinated
decal
decals
decanter
decanters
decapitated
decay
deceased
december
decent
decide
decided
decides
deciding
deciduous
decision
deck
decker
deckered
deckers
decks
declare
declawed
decline
declining
deco
decomposition
decor
decorate
dec

gearing
gears
gecko
geek
geese
gehrig
geico
geisha
geishas
gel
gelatin
gelatinous
gem
gems
gemstone
gemstones
gender
genders
gene
general
generally
generate
generated
generating
generation
generations
generic
generously
genetic
genetics
genie
genitalia
genitals
genius
genre
gentle
gentleman
gentlemen
gently
gentrified
gents
genuine
genus
geographic
geographical
geographically
geography
geologic
geological
geologically
geometric
geometrical
geometrically
geometry
george
geranium
gerberas
gerbil
gerbils
german
germany
germs
gerngross
gestation
gesture
gestures
gesturing
get
getaway
gets
getting
geyser
ghana
ghetto
ghost
ghostly
ghosts
giant
giants
gibson
gift
gifts
gigante
gigantic
gilded
gilligan
gilt
gin
ginger
gingerbread
gingy
giraffe
giraffes
girders
girl
girlfriend
girlfriends
girls
girly
give
given
gives
giving
gizmo
glacier
glad
gladiator
gladiators
glamor
glare
glares
glaring
glass
glassed
glasses
glassware
glassy
glaze
glazed
glazes
gleason
glendora
glide
glider
gliders
gliding

kong
korea
korean
kosher
krackel
kraken
kraut
kreme
kris
krispie
krispies
krispy
krups
kubrick
kumquats
kun
kutcher
l
la
lab
label
labeled
labels
labor
laboratory
laborer
labrador
lace
laced
laces
lacing
lack
lacking
lacoste
lacrosse
lactose
lacy
lad
ladder
ladders
laden
ladie
ladies
ladle
ladles
lads
lady
ladyboy
ladybug
ladylike
ladys
lagging
lagomarcino
laid
lain
lake
lakefront
lakes
lakeshore
lakeside
lamb
lamborghini
lambs
laminate
laminated
lamp
lamplighter
lamplights
lamppost
lampposts
lamps
lampshade
lampshades
lan
land
landed
landfill
landform
landforms
landing
landline
landlocked
landmark
landmarks
landmass
lands
landscape
landscaped
landscapes
landscaping
landslide
lane
lanes
language
languages
lantern
lanterns
lanyard
lanyards
lap
lapd
lapel
lapels
laps
lapse
lapsed
laptop
laptops
larceny
large
largely
larger
largest
las
lasagna
lasagne
laser
lassie
lasso
lassoed
last
lasted
lasting
latch
latched
latches
late
lately
later
laterally
latest
lather
lathered
latin
latino
latitu

opened
opener
opening
openings
opens
opentable
opera
operable
operate
operated
operates
operating
operation
operational
operations
operator
operators
opinion
opponent
opponents
opportunity
opposed
opposing
opposite
opposites
optical
optimal
optimist
optimistic
option
options
optometrist
opulent
or
oral
orange
oranges
orangutans
orb
orbach
orbiting
orbs
orchard
orchestra
orchestral
orchid
orchids
order
ordered
ordering
orderly
orders
ordinal
ordinarily
ordinary
ore
oregon
oreo
oreos
ores
organ
organic
organically
organism
organisms
organization
organize
organized
organizer
organizes
organs
orgy
oriental
orientated
orientation
orientations
oriented
orifice
origami
origin
original
originally
originate
originating
origins
orion
orleans
ornament
ornamental
ornaments
ornate
orphan
orphanage
orpheum
orthodox
orthopedic
os
oscar
osiris
ossicones
osterley
ostrich
ostriches
ostridge
ot
other
others
otherwise
otto
ottoman
ottomans
ought
ounces
our
out
outage
outboard
outbound
outburn
outdated
out

recline
recliner
recliners
reclining
recognizable
recognize
recognized
recommend
recommended
record
recorded
recorder
recorders
recording
records
recovering
recovers
recovery
recreate
recreating
recreation
recreational
recruiting
rectangle
rectangles
rectangular
recursive
recyclable
recyclables
recycle
recycled
recycling
red
reddish
redecorated
redhead
redheads
redneck
redoing
redone
reds
redstone
reduce
reduction
redundancy
redwood
reef
reel
reenactment
reenter
ref
refer
referee
reference
referenced
references
referencing
referred
referring
refers
refill
refilled
refilling
refills
refinished
reflect
reflected
reflecting
reflection
reflections
reflective
reflector
reflectors
reflects
reflex
reflexion
refreshed
refreshing
refreshment
refreshments
refride
refrigerated
refrigeration
refrigerator
refrigerators
refuel
refueled
refueling
refuge
refugees
refurbished
refurbishment
regain
regalia
regard
regarded
regarding
regards
reggae
region
regional
regions
register
registered
registering
re

stickball
sticker
stickers
sticking
sticks
sticky
stiff
stiles
stilettos
still
stiller
stilts
stink
stinky
stir
stirred
stirrer
stirring
stirrups
stitch
stitched
stitches
stitching
stix
stock
stocked
stocking
stockings
stockpiled
stocky
stockyard
stoic
stole
stolen
stomach
stomachs
stomp
stomping
stone
stoned
stonehenge
stones
stonework
stony
stood
stooges
stool
stools
stoop
stooped
stooping
stop
stoplight
stoplights
stopped
stopper
stoppers
stopping
stops
stopwatch
storage
store
stored
storefront
storefronts
storekeeper
stores
storey
storied
stories
storing
stork
storm
storming
stormtrooper
stormtroopers
stormy
story
stove
stoves
stovetop
stowaway
straddling
stragglers
straight
straighten
straightened
straightener
straightening
straightest
strain
strainer
straining
strait
stranded
strands
strange
strangely
stranger
strangers
strangest
strangle
strap
strapless
strapped
strapping
straps
strategically
strategy
strauss
straw
strawberries
strawberry
straws
stray
strays
streak
streaks
strea

vinyl
violated
violating
violation
violators
violence
violent
violently
violet
violets
violin
virgin
virginia
virginity
virtual
virtually
virtue
vis
visa
viscosity
visibility
visible
visibly
vision
visit
visited
visiting
visitor
visitors
visor
visors
vista
visual
visually
vitamin
vitamins
vivid
vocal
vocalizing
vocation
vodka
vogue
voice
void
voided
volatility
volcano
volkswagen
volkswagens
volley
volleyball
volleyballs
volleying
voltage
volts
volume
volunteer
vomit
vomiting
voodoo
vote
votes
voting
votives
vow
vowels
vows
voyager
voyages
vs
vulcan
vulgar
vulnerable
vulture
vultures
vw
w
wa
wacky
wade
wading
waffle
waffles
wage
wagging
wagon
wagons
wahlberg
wainscoting
waist
waistband
waists
wait
waiter
waiters
waiting
waitress
waits
wake
wakeboard
wakeboarder
wakeboarding
wakeboards
waking
wal
walden
waldo
walgreen
walgreens
walk
walkable
walked
walker
walkers
walking
walkman
walks
walkthrough
walkway
walkways
wall
wallace
walled
wallet
walling
wallingford
wallpaper
wallpapered
wallri

In [9]:
import torch
a = torch.arange(10).view(2,5)

In [10]:
a

tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])

In [13]:
_, i_1 = torch.max(a, 1)
_, i_2 = torch.max(a, 1)

In [14]:
i_1

tensor([4, 4])

In [15]:
i_2

tensor([4, 4])

In [12]:
a

tensor([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]])

In [16]:
with open('./demo/vocab_answers.txt', 'r') as f:
    for line in f.readlines():
        print(line)

<unk>

no

yes

2

1

white

3

red

black

blue

0

4

green

brown

yellow

5

gray

6

tennis

baseball

frisbee

nothing

orange

right

left

wood

pizza

bathroom

kitchen

pink

7

none

8

cat

dog

skiing

grass

man

water

10

silver

skateboarding

black and white

horse

kite

surfing

skateboard

giraffe

snow

tan

9

cake

surfboard

wii

broccoli

phone

living room

elephant

purple

stop

apple

12

table

soccer

food

woman

eating

sunny

banana

unknown

train

winter

sheep

hat

bus

standing

snowboarding

umbrella

male

motorcycle

beach

maybe

bear

cow

laptop

outside

wine

clear

female

zebra

camera

20

trees

many

walking

brick

metal

sitting

flowers

bedroom

bird

bench

tile

15

hot dog

night

summer

bed

11

plane

bananas

down

car

beige

sandwich

cloudy

cell phone

fork

up

beer

tree

kites

ground

donut

sand

chair

red and white

blue and white

boat

helmet

plate

glass

wall

horses

ball

bike

people

13

bat

girl

day
