In [35]:
import os
import sys
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm

%run AlexNet.ipynb

def main():
    """Train AlexNet on the flower dataset and save the best-scoring weights.

    Steps performed:

    1. Select ``cuda:0`` when CUDA is available, otherwise the CPU.
    2. Build ``ImageFolder`` datasets from ``<cwd>/../data/flower_data/train``
       and ``.../val``; training images get a random resized crop to 224 and a
       random horizontal flip, validation images are resized to 224x224. Both
       are converted to tensors and normalized with mean 0.5 / std 0.5 per
       channel.
    3. Write the index -> class-name mapping to ``class_indices.json`` in the
       current working directory.
    4. Train for 10 epochs with Adam (lr=1e-3) and cross-entropy loss,
       evaluating on the validation set after every epoch.
    5. Save ``net.state_dict()`` to ``./AlexNet.pth`` whenever the validation
       accuracy improves on the best value seen so far.

    ``AlexNet`` is not defined in this function; it is expected to already be
    in scope (this file pulls it in via ``%run AlexNet.ipynb``).

    Raises:
        FileNotFoundError: If the ``flower_data`` directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }

    # The dataset lives in a "data" directory next to the working directory's
    # parent, i.e. <cwd>/../data/flower_data.
    data_root = os.path.abspath(os.path.join(os.getcwd(), '../'))
    image_path = os.path.join(data_root, "data", "flower_data")
    # Explicit raise instead of `assert`: asserts are removed under `python -O`,
    # which would let a missing dataset slip through to a less clear failure.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exists.".format(image_path))
    print(data_root, image_path)

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    print(train_num)

    # class_to_idx maps class name -> index; invert it so a predicted index
    # can be turned back into a class name.
    flower_list = train_dataset.class_to_idx
    print(flower_list)
    cla_dict = {idx: name for name, idx in flower_list.items()}
    print(cla_dict)

    with open('class_indices.json', 'w', encoding='utf-8') as json_file:
        json.dump(cla_dict, json_file, indent=4)

    batch_size = 32
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 1 so min() never has to compare None with an int.
    cpu_count = os.cpu_count() or 1
    nw = min([cpu_count, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = DataLoader(validate_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = AlexNet()
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-3)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)  # number of batches per epoch

    for epoch in range(epochs):
        # ---- training pass ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # Pull the scalar out once and reuse it for both the running
            # total and the progress-bar text.
            loss_value = loss.item()
            running_loss += loss_value
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss_value)

        # ---- validation pass (no gradients needed) ----
        net.eval()
        acc = 0.0  # number of correctly classified validation images
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                # Index of the largest score along dim 1 is the predicted class.
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the weights with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print("Finished Training")

# Start training only when this code is executed directly (not when imported).
if __name__ == "__main__":
    main()

conv out: torch.Size([2, 128, 6, 6])
alex out: torch.Size([2, 5])
using cpu device.
E:\codes\nets_pytorch E:\codes\nets_pytorch\data\flower_data
3306
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
{0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
Using 8 dataloader workers every process
using 3306 images for training, 364 images for validation.
conv out: torch.Size([2, 128, 6, 6])
train epoch[1/10] loss:1.202: 100%|██████████████████████████████████████████████████| 104/104 [02:25<00:00,  1.40s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:10<00:00,  1.19it/s]
[epoch 1] train_loss: 1.456  val_accuracy: 0.434
train epoch[2/10] loss:1.359: 100%|██████████████████████████████████████████████████| 104/104 [02:05<00:00,  1.21s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:10<00:00,  1.16it/s]
[epoch 2] train_loss: 1.264  val_accuracy: