-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
107 lines (83 loc) · 2.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from result_viz import print_examples, plot_examples
from data_utils import get_loader
from model import MyModel
def save_checkpoint(state, filename="models/my_checkpoint_1_2048.pth.tar"):
    """Serialize a training checkpoint with ``torch.save``.

    Args:
        state: Object to serialize; in this script a dict with
            "state_dict", "optimizer" and "step" entries.
        filename: Destination path or writable file-like object. When it
            is a path, missing parent directories are created first so a
            long training run does not crash at save time.
    """
    import os

    print("=> Saving checkpoint")
    # Only real paths have a parent directory; file-like objects (which
    # torch.save also accepts) are passed through untouched.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from an already-loaded checkpoint.

    Args:
        checkpoint: Mapping with "state_dict", "optimizer" and "step" keys,
            as written by ``save_checkpoint`` in this script.
        model: Module whose parameters are overwritten.
        optimizer: Optimizer whose state is overwritten.

    Returns:
        The global step counter stored under "step".
    """
    print("=> Loading checkpoint")
    # The model is restored first, then the optimizer, matching the
    # order in which keys are looked up.
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]
# Image preprocessing: resize, take a random 300x300 crop, convert to a
# tensor, then normalize each RGB channel with mean 0.5 / std 0.5
# (maps ToTensor's [0, 1] range to [-1, 1]).
transform = transforms.Compose([
    transforms.Resize((356, 356)),
    transforms.RandomCrop((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# get_loader yields a (loader, dataset) pair; dataset.vocab is used below
# to size the output layer and locate the padding token.
train_loader, dataset = get_loader(
    transform=transform,
    num_workers=4,
    root_folder="data/Images",
    annotation_file="data/captions.txt",
)
# --- Runtime ---------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and cache the fastest conv algorithms (inputs here
# are a fixed 300x300 crop).
torch.backends.cudnn.benchmark = True

# --- Run flags -------------------------------------------------------------
load_model, save_model, train_CNN = False, True, False

# --- Hyperparameters -------------------------------------------------------
embed_size, hidden_size, num_layers = 2048, 512, 1
vocab_size = len(dataset.vocab)
learning_rate = 3e-4
num_epochs = 5

# --- TensorBoard logging; `step` is the global iteration counter ----------
writer = SummaryWriter("runs/flickr")
step = 0
# Build the captioning model, a padding-aware loss and the Adam optimizer.
model = MyModel(embed_size, hidden_size, vocab_size, num_layers).to(device)
# Positions holding the <PAD> token are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Freeze the encoder's `inception` sub-network unless train_CNN is set.
# Parameters whose names contain "fc.weight" / "fc.bias" (its fully
# connected layers) are always trainable. The optimizer already holds
# every parameter; frozen ones just never receive gradients.
for name, param in model.encoderCNN.inception.named_parameters():
    param.requires_grad = train_CNN or any(key in name for key in ("fc.weight", "fc.bias"))
# Optionally resume from a previous run. map_location lets a checkpoint
# written on a GPU machine be loaded on a CPU-only one.
if load_model:
    checkpoint = torch.load("models/my_checkpoint_1_2048.pth.tar", map_location=device)
    step = load_checkpoint(checkpoint, model, optimizer)

for epoch in tqdm(range(num_epochs), desc="Epochs: "):
    # Show sample captions before each epoch to monitor progress.
    print_examples(model, device, dataset)
    # NOTE(review): print_examples may switch the model to eval mode
    # (confirm in result_viz); re-enable training mode so dropout /
    # batch-norm behave as intended during the epoch.
    model.train()

    for imgs, captions in tqdm(train_loader, total=len(train_loader), leave=False):
        imgs = imgs.to(device)
        captions = captions.to(device)

        # NOTE(review): assumes captions are time-major (seq_len, batch) and
        # that the model prepends the image feature as the first time step,
        # so outputs line up with the full caption sequence - TODO confirm
        # against get_loader's collate function and MyModel.
        outputs = model(imgs, captions[:-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        writer.add_scalar("Training loss", loss.item(), global_step=step)
        step += 1

        optimizer.zero_grad()
        # Plain backward(): passing `loss` as the gradient argument (as the
        # original did) would scale every gradient by the current loss value.
        loss.backward()
        optimizer.step()

    # Checkpoint at the end of every epoch so the final epoch's weights
    # are saved too (previously the save ran at epoch start and skipped
    # the last epoch entirely).
    if save_model:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
        }
        save_checkpoint(checkpoint)

# Flush pending TensorBoard events and release the event file.
writer.close()

print_examples(model, device, dataset)
plot_examples(model, device, dataset, "2048_1")