-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
118 lines (90 loc) · 2.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from unicodedata import decimal
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from dataset import CarvanaDataset
from torch.utils.data import Dataset, DataLoader
from utils import (
load_checkpoint,
save_checkpoint,
check_accuracy,
save_predictions_as_imgs,
)
# ---------------------------------------------------------------------------
# Hyperparameters and paths
# ---------------------------------------------------------------------------
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use GPU when available
BATCH_SIZE = 32
NUM_EPOCH = 10
NUM_WORKERS = 2
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = True
# BUG FIX: the previous literals "data\train" / "data\train_masks" contained
# the escape sequence \t (a TAB character), i.e. the actual path was
# "data<TAB>rain". Forward slashes are portable and also valid on Windows.
IMG_DIR = "data/train"
MASK_DIR = "data/train_masks"
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch over ``loader`` using mixed-precision updates.

    Args:
        loader: DataLoader yielding ``(images, masks)`` batches.
        model: network producing logits for ``loss_fn``.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: loss taking ``(predictions, targets)``.
        scaler: ``torch.cuda.amp.GradScaler`` handling loss scaling.
    """
    progress = tqdm(loader)
    for images, masks in progress:
        images = images.to(device=DEVICE).float()
        # Masks get a channel dimension to match the model's (N, 1, H, W) output.
        masks = masks.float().unsqueeze(1).to(device=DEVICE)

        # Forward pass under autocast (mixed precision when running on CUDA).
        with torch.cuda.amp.autocast():
            logits = model(images)
            loss = loss_fn(logits, masks)

        # Backward pass with gradient scaling to avoid fp16 underflow.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Report the running batch loss on the progress bar.
        progress.set_postfix(loss=loss.item())
def main():
    """Train UNET on the Carvana data, checkpointing and evaluating each epoch."""
    # FIX: Normalize added here to match val_transforms below — previously the
    # training pipeline fed raw [0, 255] pixels while the validation pipeline
    # fed normalized ones, a train/eval input-distribution mismatch.
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ]
    )
    # NOTE(review): val_transforms is defined but never applied — the random
    # split below shares one dataset built with train_transform. Kept for
    # interface stability; wiring it to a separate eval dataset needs
    # CarvanaDataset support. TODO confirm.
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = UNET(in_channels=3, out_channel=1).to(DEVICE)
    # BCEWithLogitsLoss: the model's single-channel output is raw logits.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # 80/20 train/test split of the full dataset.
    dataset = CarvanaDataset(transform=train_transform)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )

    # FIX: use the module-level hyperparameters instead of a hard-coded
    # batch_size=10, and pass NUM_WORKERS / PIN_MEMORY, which were defined
    # but never used.
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # evaluation data needs no shuffling
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )

    # Training loop.
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(NUM_EPOCH):
        # BUG FIX: the train_fn call was commented out, so the loop only
        # evaluated an untrained model NUM_EPOCH times without training.
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # Save a checkpoint after every epoch (previously commented out).
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # Evaluate on the held-out split.
        check_accuracy(test_loader, model, device=DEVICE)

        # Dump example predictions for visual inspection.
        save_predictions_as_imgs(
            test_loader, model, folder="saved_images/", device=DEVICE
        )
# Entry point: start training only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()