<img src="https://i.imgur.com/fOeJmu3.jpg" width="600px">

# Introduction

Hello everyone! In this kernel, I will try to finetune roBERTa base to predict sentiments instead of selected_text. I believe that **explicitly teaching the model to understand the sentiment of a tweet can help it understand which part of the tweet expresses that sentiment**. So, I have decided to undertake a double-finetuning procesdure where I finetune roBERTa base on sentiment classification and then finetune the finetuned model on sentiment extraction (the competition goal). I feel this process might improve the ability of the model to predict selected_text correctly.

This is part one of the approach; I will train the sentiment extraction model in part two :D

## Set up PyTorch-XLA

In [1]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
!export XLA_USE_BF16=1



## Install and import libraries

In [2]:
!pip install -q colored
!pip install -q transformers

In [3]:
import os
import gc
import re

import time
import colored
import numpy as np
import pandas as pd
from colored import fg, bg, attr

import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
from plotly.subplots import make_subplots

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm
from sklearn.utils import shuffle
from transformers import RobertaModel, RobertaTokenizer

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

from keras.utils import to_categorical
from keras.preprocessing.sequence import pad_sequences as pad

Using TensorFlow backend.


## Define hyperparameters and load data

In [4]:
EPOCHS = 8
SPLIT = 0.8
MAXLEN = 48
DROP_RATE = 0.225

OUTPUT_UNITS = 3
BATCH_SIZE = 384
LR = (4e-5, 1e-2)
ROBERTA_UNITS = 768
VAL_BATCH_SIZE = 384
MODEL_SAVE_PATH = 'sentiment_model.pt'

np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
test_df = pd.read_csv('../input/tweet-sentiment-extraction/test.csv')
train_df = pd.read_csv('../input/tweet-sentiment-extraction/train.csv')

In [6]:
test_df.head()

Unnamed: 0,textID,text,sentiment
0,f87dea47db,Last session of the day http://twitpic.com/67ezh,neutral
1,96d74cb729,Shanghai is also really exciting (precisely -...,positive
2,eee518ae67,"Recession hit Veronique Branquinho, she has to...",negative
3,01082688c6,happy bday!,positive
4,33987a8ee5,http://twitpic.com/4w75p - I like it!!,positive


In [7]:
train_df.head()

Unnamed: 0,textID,text,selected_text,sentiment
0,cb774db0d1,"I`d have responded, if I were going","I`d have responded, if I were going",neutral
1,549e992a42,Sooo SAD I will miss you here in San Diego!!!,Sooo SAD,negative
2,088c60f138,my boss is bullying me...,bullying me,negative
3,9642c003ef,what interview! leave me alone,leave me alone,negative
4,358bd9e861,"Sons of ****, why couldn`t they put them on t...","Sons of ****,",negative


## Define PyTorch Dataset

In [8]:
class TweetDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.text = data.text
        self.tokenizer = tokenizer
        self.sentiment = data.sentiment
        self.sentiment_dict = {"positive": 0, "neutral": 1, "negative": 2}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        pg, tg = 'post', 'post'
        tweet = str(self.text[i]).strip()
        tweet_ids = self.tokenizer.encode(tweet)

        attention_mask_idx = len(tweet_ids) - 1
        if 0 not in tweet_ids: tweet_ids = 0 + tweet_ids
        tweet_ids = pad([tweet_ids], maxlen=MAXLEN, value=1, padding=pg, truncating=tg)

        attention_mask = np.zeros(MAXLEN)
        attention_mask[1:attention_mask_idx] = 1
        attention_mask = attention_mask.reshape((1, -1))
        if 2 not in tweet_ids: tweet_ids[-1], attention_mask[-1] = 2, 0
            
        sentiment = [self.sentiment_dict[self.sentiment[i]]]
        sentiment = torch.FloatTensor(to_categorical(sentiment, num_classes=3))
        return sentiment, torch.LongTensor(tweet_ids), torch.LongTensor(attention_mask)

## Define roBERTa-base model

In [9]:
class Roberta(nn.Module):
    def __init__(self):
        super(Roberta, self).__init__()
        self.softmax = nn.Softmax(dim=1)
        self.drop = nn.Dropout(DROP_RATE)
        self.roberta = RobertaModel.from_pretrained(model)
        self.dense = nn.Linear(ROBERTA_UNITS, OUTPUT_UNITS)
        
    def forward(self, inp, att):
        inp = inp.view(-1, MAXLEN)
        _, self.feat = self.roberta(inp, att)
        return self.softmax(self.dense(self.drop(self.feat)))

## Split data and define tokenizer

In [10]:
train_df = shuffle(train_df)
split = np.int32(SPLIT*len(train_df))
val_df, train_df = train_df[split:], train_df[:split]

In [11]:
model = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(model)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




## Define cross entropy and accuracy

In [12]:
def cel(inp, target):
    _, labels = target.max(dim=1)
    return nn.CrossEntropyLoss()(inp, labels)

def accuracy(inp, target):
    inp_ind = inp.max(axis=1).indices
    target_ind = target.max(axis=1).indices
    return (inp_ind == target_ind).float().sum(axis=0)/len(inp_ind)

## Train model on TPU

In [13]:
def print_metric(data, batch, epoch, start, end, metric, typ):

    cs = [211, 212, 213]
    time = np.round(end - start, 1)
    time = "Time: %s{}%s s".format(time)
    fonts = [(fg(c), attr('reset')) for c in cs]
    if typ == "Train": pre = "BATCH %s" + str(batch-1) + "%s  "
    if typ == "Val": pre = "\nEPOCH %s" + str(epoch+1) + "%s  "

    print(pre % fonts[0] , end='')
    t = typ, metric, "%s", data, "%s"
    print("{} {}: {}{}{}".format(*t) % fonts[1] + "  " + time % fonts[2])

In [14]:
device = xm.xla_device()

val_df = val_df.reset_index(drop=True)
val_set = TweetDataset(val_df, tokenizer)
val_loader = DataLoader(val_set, batch_size=VAL_BATCH_SIZE)

train_df = train_df.reset_index(drop=True)
train_set = TweetDataset(train_df, tokenizer)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    
network = Roberta().to(device)
optimizer = Adam([{'params': network.dense.parameters(), 'lr': LR[1]},
                  {'params': network.roberta.parameters(), 'lr': LR[0]}])

val_losses, val_accuracies = [], []
train_losses, train_accuracies = [], []
    
start = time.time()
print("STARTING TRAINING ...\n")

for epoch in range(EPOCHS):

    batch = 1
    network.train()
    fonts = (fg(48), attr('reset'))
    print(("EPOCH %s" + str(epoch+1) + "%s") % fonts)
        
    for train_batch in train_loader:
        train_targ, train_in, train_att = train_batch
            
        network = network.to(device)
        train_in = train_in.to(device)
        train_att = train_att.to(device)
        train_targ = train_targ.to(device)

        train_preds = network.forward(train_in, train_att)
        train_loss = cel(train_preds, train_targ.squeeze(dim=1))
        train_accuracy = accuracy(train_preds, train_targ.squeeze(dim=1))

        optimizer.zero_grad()
        train_loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
            
        end = time.time()
        batch = batch + 1
        acc = np.round(train_accuracy.item(), 3)
        print_metric(acc, batch, None, start, end, metric="acc", typ="Train")

    network.eval()
    val_loss, val_accuracy = 0, 0
    for val_batch in tqdm(val_loader):
        targ, val_in, val_att = val_batch

        with torch.no_grad():
            targ = targ.to(device)
            val_in = val_in.to(device)
            val_att = val_att.to(device)
            network = network.to(device)

            pred = network.forward(val_in, val_att)
            val_loss += cel(pred, targ.squeeze(dim=1)).item()*len(pred)
            val_accuracy += accuracy(pred, targ.squeeze(dim=1)).item()*len(pred)
        
    end = time.time()
    val_loss /= len(val_set)
    val_accuracy /= len(val_set)
    acc = np.round(val_accuracy, 3)
    print_metric(acc, None, epoch, start, end, metric="acc", typ="Val")
    
    print("")
    val_losses.append(val_loss); train_losses.append(train_loss)
    val_accuracies.append(val_accuracy); train_accuracies.append(train_accuracy)
    
print("ENDING TRAINING ...")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=481.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501200538.0, style=ProgressStyle(descri…


STARTING TRAINING ...

EPOCH [38;5;48m1[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.299[0m  Time: [38;5;213m44.3[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.372[0m  Time: [38;5;213m92.5[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.409[0m  Time: [38;5;213m93.5[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.427[0m  Time: [38;5;213m94.5[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.391[0m  Time: [38;5;213m95.5[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.396[0m  Time: [38;5;213m96.4[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.401[0m  Time: [38;5;213m97.4[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.404[0m  Time: [38;5;213m98.4[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.359[0m  Time: [38;5;213m99.3[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.404[0m  Time: [38;5;213m100.3[0m s
BATCH [38;5;211m11[0m  Train acc: [38;5;212m0.359[0m  Time: [38;5;213m101.3[0m s
BATCH [38;5;21

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m1[0m  Val acc: [38;5;212m0.687[0m  Time: [38;5;213m221.0[0m s

EPOCH [38;5;48m2[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.677[0m  Time: [38;5;213m290.3[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.646[0m  Time: [38;5;213m291.3[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.685[0m  Time: [38;5;213m292.2[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.701[0m  Time: [38;5;213m293.1[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.628[0m  Time: [38;5;213m294.0[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.68[0m  Time: [38;5;213m294.9[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.708[0m  Time: [38;5;213m295.8[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.648[0m  Time: [38;5;213m296.7[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.555[0m  Time: [38;5;213m297.6[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.578[0m  Time: [38;5;213m298.4[0m s
BATCH [38;5;211m11[0m  Train 

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m2[0m  Val acc: [38;5;212m0.732[0m  Time: [38;5;213m349.6[0m s

EPOCH [38;5;48m3[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.745[0m  Time: [38;5;213m349.9[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.747[0m  Time: [38;5;213m350.9[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.771[0m  Time: [38;5;213m351.8[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.74[0m  Time: [38;5;213m352.7[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.693[0m  Time: [38;5;213m353.6[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.701[0m  Time: [38;5;213m354.5[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.727[0m  Time: [38;5;213m355.3[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.706[0m  Time: [38;5;213m356.2[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.711[0m  Time: [38;5;213m357.1[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.753[0m  Time: [38;5;213m358.0[0m s
BATCH [38;5;211m11[0m  Train 

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m3[0m  Val acc: [38;5;212m0.762[0m  Time: [38;5;213m409.0[0m s

EPOCH [38;5;48m4[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.766[0m  Time: [38;5;213m409.3[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.768[0m  Time: [38;5;213m410.2[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.784[0m  Time: [38;5;213m411.1[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.758[0m  Time: [38;5;213m412.0[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.792[0m  Time: [38;5;213m412.9[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.74[0m  Time: [38;5;213m413.8[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.781[0m  Time: [38;5;213m414.7[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.784[0m  Time: [38;5;213m415.6[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.76[0m  Time: [38;5;213m416.5[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.773[0m  Time: [38;5;213m417.3[0m s
BATCH [38;5;211m11[0m  Train a

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m4[0m  Val acc: [38;5;212m0.771[0m  Time: [38;5;213m468.1[0m s

EPOCH [38;5;48m5[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.812[0m  Time: [38;5;213m468.4[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.74[0m  Time: [38;5;213m469.4[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.81[0m  Time: [38;5;213m470.3[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.768[0m  Time: [38;5;213m471.2[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.797[0m  Time: [38;5;213m472.1[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.807[0m  Time: [38;5;213m472.9[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.807[0m  Time: [38;5;213m473.8[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.789[0m  Time: [38;5;213m474.7[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.781[0m  Time: [38;5;213m475.6[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.797[0m  Time: [38;5;213m476.5[0m s
BATCH [38;5;211m11[0m  Train a

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m5[0m  Val acc: [38;5;212m0.772[0m  Time: [38;5;213m527.3[0m s

EPOCH [38;5;48m6[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.797[0m  Time: [38;5;213m527.6[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.812[0m  Time: [38;5;213m528.5[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.833[0m  Time: [38;5;213m529.4[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.805[0m  Time: [38;5;213m530.3[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.792[0m  Time: [38;5;213m531.2[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.812[0m  Time: [38;5;213m532.1[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.807[0m  Time: [38;5;213m533.0[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.771[0m  Time: [38;5;213m533.9[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.781[0m  Time: [38;5;213m534.8[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.766[0m  Time: [38;5;213m535.6[0m s
BATCH [38;5;211m11[0m  Train

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m6[0m  Val acc: [38;5;212m0.775[0m  Time: [38;5;213m586.4[0m s

EPOCH [38;5;48m7[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.797[0m  Time: [38;5;213m586.7[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.818[0m  Time: [38;5;213m587.7[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.797[0m  Time: [38;5;213m588.6[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.807[0m  Time: [38;5;213m589.5[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.812[0m  Time: [38;5;213m590.4[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.81[0m  Time: [38;5;213m591.3[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.799[0m  Time: [38;5;213m592.2[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.797[0m  Time: [38;5;213m593.1[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.745[0m  Time: [38;5;213m594.0[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.784[0m  Time: [38;5;213m594.9[0m s
BATCH [38;5;211m11[0m  Train 

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m7[0m  Val acc: [38;5;212m0.783[0m  Time: [38;5;213m645.7[0m s

EPOCH [38;5;48m8[0m
BATCH [38;5;211m1[0m  Train acc: [38;5;212m0.755[0m  Time: [38;5;213m646.0[0m s
BATCH [38;5;211m2[0m  Train acc: [38;5;212m0.833[0m  Time: [38;5;213m646.9[0m s
BATCH [38;5;211m3[0m  Train acc: [38;5;212m0.802[0m  Time: [38;5;213m647.8[0m s
BATCH [38;5;211m4[0m  Train acc: [38;5;212m0.779[0m  Time: [38;5;213m648.7[0m s
BATCH [38;5;211m5[0m  Train acc: [38;5;212m0.805[0m  Time: [38;5;213m649.6[0m s
BATCH [38;5;211m6[0m  Train acc: [38;5;212m0.766[0m  Time: [38;5;213m650.5[0m s
BATCH [38;5;211m7[0m  Train acc: [38;5;212m0.805[0m  Time: [38;5;213m651.4[0m s
BATCH [38;5;211m8[0m  Train acc: [38;5;212m0.789[0m  Time: [38;5;213m652.3[0m s
BATCH [38;5;211m9[0m  Train acc: [38;5;212m0.818[0m  Time: [38;5;213m653.2[0m s
BATCH [38;5;211m10[0m  Train acc: [38;5;212m0.789[0m  Time: [38;5;213m654.1[0m s
BATCH [38;5;211m11[0m  Train

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



EPOCH [38;5;211m8[0m  Val acc: [38;5;212m0.781[0m  Time: [38;5;213m704.8[0m s

ENDING TRAINING ...


## Check testing performance

In [17]:
test_set = TweetDataset(test_df, tokenizer)
test_loader = DataLoader(test_set, batch_size=VAL_BATCH_SIZE)

network.eval()
for test_batch in tqdm(test_loader):
    test_preds, test_targs = [], []
    targ, test_in, test_att = test_batch

    with torch.no_grad():
        network = network.to(device)
        test_in = test_in.to(device)
        test_att = test_att.to(device)
        pred = network.forward(test_in, test_att)
        test_preds.append(pred); test_targs.append(targ)

test_preds = torch.cat(test_preds, axis=0)
test_targs = torch.cat(test_targs, axis=0)
test_accuracy = accuracy(test_preds, test_targs.squeeze(dim=1).to(device))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [18]:
fonts = (fg(212), attr('reset'))
acc = np.round(test_accuracy.item()*100, 2)
print("{}: {}{}{}".format("Test acc", "%s", str(acc), "%s") % fonts + " %")

Test acc: [38;5;212m91.03[0m %


## Sample sentiment prediction

In [19]:
def predict_sentiment(tweet):
    pg, tg = 'post', 'post'
    tweet_ids = tokenizer.encode(tweet.strip())
    sent = {0: 'positive', 1: 'neutral', 2: 'negative'}

    att_mask_idx = len(tweet_ids) - 1
    if 0 not in tweet_ids: tweet_ids = 0 + tweet_ids
    tweet_ids = pad([tweet_ids], maxlen=MAXLEN, value=1, padding=pg, truncating=tg)

    att_mask = np.zeros(MAXLEN)
    att_mask[1:att_mask_idx] = 1
    att_mask = att_mask.reshape((1, -1))
    if 2 not in tweet_ids: tweet_ids[-1], att_mask[-1] = 2, 0
    tweet_ids, att_mask = torch.LongTensor(tweet_ids), torch.LongTensor(att_mask)
    return sent[np.argmax(network.forward(tweet_ids.to(device), att_mask.to(device)).detach().cpu().numpy())]

In [20]:
predict_sentiment("It was okay I guess?")

'neutral'

In [21]:
predict_sentiment("I want to hide omg !!!")

'neutral'

In [22]:
predict_sentiment("I am feeling great today ...")

'positive'

## Save model for PART 2

In [23]:
network = network.cpu()
torch.save(network.state_dict(), MODEL_SAVE_PATH)