In [1]:
import sys
import torch
from torch import nn
import torchvision.models as models
import time
from tqdm import tqdm
import torch.optim as optim
import copy

# Specify where to find the data preparation class
sys.path.append('../../Data_Preparation')
from Preparation import CustomDataLoader

In [2]:
# InceptionV3 training data (ImageNet) properties
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
DIMENSIONS = 3
BATCH_SIZE = 16

LR = 0.1
MOMENTUM=0.9
WEIGHT_DECAY = 1e-4
LR_STEP_SIZE = 30
LR_GAMMA = 0.1

EPOCHS = 90

In [3]:
print("Creating data loaders")
# Instantiate the CustomDataLoader class for training
train_data_loader = CustomDataLoader(data_path="../../FER2013_Data", batch_size=BATCH_SIZE, dataset_type="train", mean=MEAN, std=STD, dimensions=3).data_loader
test_data_loader = CustomDataLoader(data_path="../../FER2013_Data", batch_size=BATCH_SIZE, dataset_type="test", mean=MEAN, std=STD, dimensions=3).data_loader

# Confirm correct data load
print("Train Data Loader:")
for batch_idx, (inputs, labels) in enumerate(train_data_loader):
    print("Batch Index:", batch_idx)
    print("Inputs Shape:", inputs.shape)
    print("Labels Shape:", labels.shape)
    # Print the first few labels in the batch
    print("Labels:", labels[:5])
    # Break after printing a few batches
    if batch_idx == 2:
        break

Creating data loaders
Train Data Loader:
Batch Index: 0
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([4, 4, 3, 2, 6])
Batch Index: 1
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([3, 0, 3, 0, 3])
Batch Index: 2
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([5, 6, 4, 6, 2])


In [7]:
# load up the InceptionV3 model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.inception_v3(pretrained=True)
model.eval()

for parameter in model.parameters():
    parameter.requires_grad = False

# Replace the last fully connected layer with a new one that outputs 7 classes
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 7)  # Output layer with 7 classes
model.aux_logits = False
model.AuxLogits = None

model = model.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
)

LR_SCHEDULER = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

#torch.optim.Adam(model.parameters(), lr=0.0001)

In [None]:
print("[INFO] Start training")
print("---------------------")

for epoch in tqdm(range(EPOCHS)):

    total_batch = len(train_data_loader.dataset)//BATCH_SIZE
    
    # Training phase
    model.train()  # Set the model to train mode

    # Get statistics
    epoch_loss = 0
    len_dataset = 0
    
    for step, (batch_images, batch_labels) in enumerate(train_data_loader):
        X, Y = batch_images.to(device), batch_labels.to(device)

        pred = model(X)
        loss = criterion(pred, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += pred.shape[0] * loss.item()
        len_dataset += pred.shape[0]
        if (step) % 10 == 0:
            print('Epoch [%d/%d], lter [%d/%d] Loss: %.4f'
                 %(epoch+1, EPOCHS, step+1, total_batch, loss.item()))

    epoch_loss = epoch_loss/ len_dataset
    print('Epoch: ', epoch+1, '| train loss : %0.4f' % epoch_loss)

    LR_SCHEDULER.step()
    
    # Validation phase
    model.eval()  # Set the model to evaluation mode
    with torch.inference_mode():  
        running_loss = 0
        for step, (images, labels) in enumerate(test_data_loader):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            
    running_loss = running_loss / len(test_data_loader)
    print('Epoch: ', epoch, '| test loss : %0.4f' % running_loss )


[INFO] Start training
---------------------


  0%|          | 0/90 [00:00<?, ?it/s]

Epoch [1/90], lter [1/1752] Loss: 24.5209
Epoch [1/90], lter [11/1752] Loss: 20.5382
Epoch [1/90], lter [21/1752] Loss: 29.8732
Epoch [1/90], lter [31/1752] Loss: 27.7454
Epoch [1/90], lter [41/1752] Loss: 26.2988
Epoch [1/90], lter [51/1752] Loss: 20.3020
Epoch [1/90], lter [61/1752] Loss: 13.5113
Epoch [1/90], lter [71/1752] Loss: 34.0717
Epoch [1/90], lter [81/1752] Loss: 26.3085
Epoch [1/90], lter [91/1752] Loss: 35.8470
Epoch [1/90], lter [101/1752] Loss: 19.8037
Epoch [1/90], lter [111/1752] Loss: 26.9462
Epoch [1/90], lter [121/1752] Loss: 20.1170
Epoch [1/90], lter [131/1752] Loss: 18.9827
Epoch [1/90], lter [141/1752] Loss: 26.6137
Epoch [1/90], lter [151/1752] Loss: 29.6315
Epoch [1/90], lter [161/1752] Loss: 20.1835
Epoch [1/90], lter [171/1752] Loss: 27.0698
Epoch [1/90], lter [181/1752] Loss: 25.5315
Epoch [1/90], lter [191/1752] Loss: 28.0733
Epoch [1/90], lter [201/1752] Loss: 23.2970
Epoch [1/90], lter [211/1752] Loss: 30.7603
Epoch [1/90], lter [221/1752] Loss: 32.9085

  1%|          | 1/90 [53:35<79:29:07, 3215.15s/it]

Epoch:  0 | test loss : 19.7356
Epoch [2/90], lter [1/1752] Loss: 37.5264
Epoch [2/90], lter [11/1752] Loss: 15.1287
Epoch [2/90], lter [21/1752] Loss: 36.8454
Epoch [2/90], lter [31/1752] Loss: 12.7344
Epoch [2/90], lter [41/1752] Loss: 26.2758
Epoch [2/90], lter [51/1752] Loss: 26.8353
Epoch [2/90], lter [61/1752] Loss: 38.3289
Epoch [2/90], lter [71/1752] Loss: 29.4488
Epoch [2/90], lter [81/1752] Loss: 24.5433
Epoch [2/90], lter [91/1752] Loss: 27.0105
Epoch [2/90], lter [101/1752] Loss: 17.0444
Epoch [2/90], lter [111/1752] Loss: 35.8786
Epoch [2/90], lter [121/1752] Loss: 28.0786
Epoch [2/90], lter [131/1752] Loss: 21.9039
Epoch [2/90], lter [141/1752] Loss: 20.4429
Epoch [2/90], lter [151/1752] Loss: 33.6597
Epoch [2/90], lter [161/1752] Loss: 26.4655
Epoch [2/90], lter [171/1752] Loss: 36.5336
Epoch [2/90], lter [181/1752] Loss: 29.4748
Epoch [2/90], lter [191/1752] Loss: 24.1091
Epoch [2/90], lter [201/1752] Loss: 28.1159
Epoch [2/90], lter [211/1752] Loss: 24.0201
Epoch [2/90

  2%|▏         | 2/90 [1:47:24<78:48:13, 3223.79s/it]

Epoch:  1 | test loss : 21.0203
Epoch [3/90], lter [1/1752] Loss: 26.4743
Epoch [3/90], lter [11/1752] Loss: 22.4584
Epoch [3/90], lter [21/1752] Loss: 13.2419
Epoch [3/90], lter [31/1752] Loss: 24.3960
Epoch [3/90], lter [41/1752] Loss: 40.9223
Epoch [3/90], lter [51/1752] Loss: 26.5597
Epoch [3/90], lter [61/1752] Loss: 21.4749
Epoch [3/90], lter [71/1752] Loss: 19.0224
Epoch [3/90], lter [81/1752] Loss: 24.0370
Epoch [3/90], lter [91/1752] Loss: 31.8756
Epoch [3/90], lter [101/1752] Loss: 13.6559
Epoch [3/90], lter [111/1752] Loss: 24.5176
Epoch [3/90], lter [121/1752] Loss: 24.2451
Epoch [3/90], lter [131/1752] Loss: 32.2007
Epoch [3/90], lter [141/1752] Loss: 23.5167
Epoch [3/90], lter [151/1752] Loss: 17.9500
Epoch [3/90], lter [161/1752] Loss: 19.5583
Epoch [3/90], lter [171/1752] Loss: 29.4886
Epoch [3/90], lter [181/1752] Loss: 21.7500
Epoch [3/90], lter [191/1752] Loss: 16.2226
Epoch [3/90], lter [201/1752] Loss: 25.2413
Epoch [3/90], lter [211/1752] Loss: 19.7464
Epoch [3/90

In [None]:
# Saving the model
torch.save(model.state_dict(), 'trained_inception_v3.pt')

# GARBAGE

In [None]:
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(5)):
	# set the model in training mode
	model.train()
	# initialize the total training and validation loss
	totalTrainLoss = 0
	totalValLoss = 0
	# initialize the number of correct predictions in the training
	# and validation step
	trainCorrect = 0
	valCorrect = 0
	# loop over the training set
	for (step, (x, y)) in enumerate(train_data_loader):
		# send the input to the device
		#(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
		# perform a forward pass and calculate the training loss
		pred = model(x)
		loss = lossFunc(pred, y)
		# calculate the gradients
		loss.backward()
		# check if we are updating the model parameters and if so
		# update them, and zero out the previously accumulated gradients
		if (step + 2) % 2 == 0:
			opt.step()
			opt.zero_grad()
		# add the loss to the total training loss so far and
		# calculate the number of correct predictions
		totalTrainLoss += loss
		trainCorrect += (pred.argmax(1) == y).type(
			torch.float).sum().item()

    	# switch off autograd
	with torch.no_grad():
		# set the model in evaluation mode
		model.eval()
		# loop over the validation set
		for (x, y) in test_data_loader:
			# send the input to the device
			#(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
			# make the predictions and calculate the validation loss
			pred = model(x)
			totalValLoss += lossFunc(pred, y)
			# calculate the number of correct predictions
			valCorrect += (pred.argmax(1) == y).type(
				torch.float).sum().item()

    	# calculate the average training and validation loss
	avgTrainLoss = totalTrainLoss / trainSteps
	avgValLoss = totalValLoss / valSteps
	# calculate the training and validation accuracy
	trainCorrect = trainCorrect / len(trainDS)
	valCorrect = valCorrect / len(valDS)
	# update our training history
	H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
	H["train_acc"].append(trainCorrect)
	H["val_loss"].append(avgValLoss.cpu().detach().numpy())
	H["val_acc"].append(valCorrect)
	# print the model training and validation information
	print("[INFO] EPOCH: {}/{}".format(e + 1, 5))
	print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
		avgTrainLoss, trainCorrect))
	print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
		avgValLoss, valCorrect))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
	endTime - startTime))
# plot the training loss and accuracy
# plt.style.use("ggplot")
# plt.figure()
# plt.plot(H["train_loss"], label="train_loss")
# plt.plot(H["val_loss"], label="val_loss")
# plt.plot(H["train_acc"], label="train_acc")
# plt.plot(H["val_acc"], label="val_acc")
# plt.title("Training Loss and Accuracy on Dataset")
# plt.xlabel("Epoch #")
# plt.ylabel("Loss/Accuracy")
# plt.legend(loc="lower left")
# plt.savefig(config.WARMUP_PLOT)
# serialize the model to disk
torch.save(model, config.WARMUP_MODEL)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for inputs, labels in train_data_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes
            epoch_acc = running_corrects.double() / dataset_sizes
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            # deep copy the model
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)