In [None]:
# import libraries

import torch
from torch import nn
from fastai import * 
import pandas as pd
import numpy as np
from functools import partial
import io
import os
from fastai.text import *

## **Load Dataset**

In [None]:
# Download (or reuse the cached copy of) the AG News archive and list the extracted files.
path = untar_data(URLs.AG_NEWS)
path.ls()

[PosixPath('/root/.fastai/data/ag_news_csv/train.csv'),
 PosixPath('/root/.fastai/data/ag_news_csv/classes.txt'),
 PosixPath('/root/.fastai/data/ag_news_csv/readme.txt'),
 PosixPath('/root/.fastai/data/ag_news_csv/test.csv')]

In [None]:
# The CSV has no header row: without `header=None` pandas consumes the first
# article as column names and silently drops it from the data.
# Column names follow the visible data (class id, headline, body text).
# Read via `path` (returned by `untar_data`) rather than a hard-coded /root location.
df_train = pd.read_csv(path/'train.csv', header=None, names=['label', 'title', 'description'])
df_train.head()

Unnamed: 0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
0,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
1,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
2,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
3,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."
4,3,"Stocks End Up, But Near Year Lows (Reuters)",Reuters - Stocks ended slightly higher on Frid...


In [None]:
# The CSV has no header row: without `header=None` pandas consumes the first
# article as column names and silently drops it from the data.
# Column names follow the visible data (class id, headline, body text).
# Read via `path` (returned by `untar_data`) rather than a hard-coded /root location.
df_valid = pd.read_csv(path/'test.csv', header=None, names=['label', 'title', 'description'])
df_valid.head()

Unnamed: 0,3,Fears for T N pension after talks,Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.
0,4,The Race is On: Second Private Team Sets Launc...,"SPACE.com - TORONTO, Canada -- A second\team o..."
1,4,Ky. Company Wins Grant to Study Peptides (AP),AP - A company founded by a chemistry research...
2,4,Prediction Unit Helps Forecast Wildfires (AP),AP - It's barely dawn when Mike Fitzpatrick st...
3,4,Calif. Aims to Limit Farm-Related Smog (AP),AP - Southern California's smog-fighting agenc...
4,4,Open Letter Against British Copyright Indoctri...,The British Department for Education and Skill...


In [None]:
# (rows, columns) of the training and validation frames.
df_train.shape, df_valid.shape

((119999, 3), (7599, 3))

In [None]:
# Batch size passed to `databunch(bs=...)` when the DataBunch is built.
bs = 12

In [None]:
# Wrap each DataFrame in a TextList rooted at the dataset folder.
train, valid = (TextList(frame, path=path) for frame in (df_train, df_valid))

# Pair the two lists, label them for language modelling, then batch them.
src = ItemLists(path=path, train=train, valid=valid).label_for_lm()
data = src.databunch(bs=bs)

In [None]:
# Preview a few tokenized text rows from the DataBunch.
data.show_batch()

idx,text
0,xxmaj stocks ' xxmaj outlook ( xxmaj reuters ) xxmaj reuters - xxmaj soaring crude prices plus worries \ about the economy and the outlook for earnings are expected to \ hang over the stock market next week during the depth of the \ summer doldrums . xxbos 3 xxmaj iraq xxmaj halts xxmaj oil xxmaj exports from xxmaj main xxmaj southern xxmaj pipeline ( xxmaj reuters ) xxmaj reuters
1,"xxmaj open source makes noise xxmaj developers forum looks at xxmaj linux for the disabled . xxmaj also : xxmaj microsoft chided for ad campaign . xxbos 4 xxmaj microsoft wraps up xxup mom 2005 management tool xxmaj microsoft xxmaj corp. on xxmaj wednesday said it has finished work on xxmaj microsoft xxmaj operations xxmaj manager ( xxup mom ) 2005 , a major update to its xxup mom 2000"
2,"- down sectors such as technology and insurers . xxbos 3 xxmaj monti backs bid for xxmaj abbey xxup eu xxup competition commissioner xxmaj mario xxmaj monti has backed xxmaj banco xxmaj santander xxmaj centrale xxmaj xxunk 8.75 billion bid for xxmaj abbey xxmaj national . xxmaj monti suggested support for a xxmaj spanish bid , saying : ' xxmaj at first glance , it is a contribution xxbos 3"
3,"satellite xxmaj radio xxmaj holdings xxmaj inc. ( xxup xmsr ) will soon begin broadcasting some of its stations to subscribers over the xxmaj internet , fresh on the heels of the company 's discontinuation of a receiver for pcs that some users used to circumvent the music industry 's crackdown on illegal file sharing . xxbos 4 xxup at t xxmaj wireless to xxmaj offer xxmaj xxunk xxmaj messaging"
4,"xxmaj president xxmaj hamid xxmaj karzai 's swearing - in ceremony , but the xxup us - led military said watertight security on the ground and in the air would stop any attack . xxbos 1 xxup un chief promises more staff to xxmaj iraq when possible xxup un xxmaj secretary - xxmaj general xxmaj kofi xxmaj annan assured the xxmaj iraqis that the world body will provide assistance to"


In [None]:
# Vocabulary attached to the validation dataset; its `itos` / `stoi` tables are used below.
# NOTE(review): presumably the same vocab as the training set — confirm.
v = data.valid_ds.vocab

In [None]:
# Sizes of the index->token list and of the token->index mapping.
len(v.itos), len(v.stoi)

(37264, 75055)

In [None]:
# Vocabulary size; used as both the input and output dimension of the model.
n_text = len(v.itos) ; n_text

37264

In [None]:
# Number of batches in the training and validation data loaders.
len(data.train_dl), len(data.valid_dl)

(8055, 508)

## **Model**

Device Configuration

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

LSTM Module

In [None]:
# Alias of `bs` (the batch size the DataBunch was built with) so the two values cannot drift apart.
batch_size = bs

In [None]:
class RNN(nn.Module):
    """Single-step LSTM language model: embedding -> stacked LSTM -> linear decoder.

    The model consumes one token per sequence per call; the caller threads the
    recurrent state (`hidden`, `cell`) from one call to the next.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size):
        """
        Initialize the model by setting up the layers.

        Args:
            input_size: number of distinct token ids (rows of the embedding table).
            hidden_size: width of both the embedding vectors and the LSTM state.
            num_layers: number of stacked LSTM layers.
            output_size: number of output scores produced per input token.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Embedding and LSTM layers
        self.embed = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

        # Linear layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        """
        Perform a forward pass of our model on some input and hidden state.

        Args:
            x: 1-D tensor of token ids, one per sequence in the batch.
            hidden: LSTM hidden state, shaped like the first value of `init_hidden`.
            cell: LSTM cell state, shaped like the second value of `init_hidden`.

        Returns:
            `(out, (hidden, cell))` where `out` has one row of `output_size`
            scores per input token and `(hidden, cell)` is the updated state.
        """
        out = self.embed(x)
        # The LSTM is batch_first, so add a length-1 sequence axis for this single step.
        out, (hidden, cell) = self.lstm(out.unsqueeze(1), (hidden, cell))
        # Drop the length-1 sequence axis again before decoding.
        out = self.fc(out.reshape(out.shape[0], -1))
        return out, (hidden, cell)

    def init_hidden(self, batch_size):
        """Return zeroed `(hidden, cell)` states for `batch_size` sequences.

        The tensors are created on the same device (and with the same dtype) as
        the model's own parameters, so the state always matches the model even
        after it has been moved with `.to(...)`.
        """
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        hidden = reference.new_zeros(shape)
        cell = reference.new_zeros(shape)
        return hidden, cell

In [None]:
class Generator():
    """Trains the `RNN` language model on DataBunch batches and samples text from it.

    Relies on names defined earlier in the notebook: `data` (the DataBunch),
    `v` (its vocabulary), `n_text` (vocabulary size), `device` and `RNN`.
    """

    def __init__(self):
        self.chunk_len = 70        # maximum number of time steps used from each batch
        self.num_epochs = 200      # number of training iterations (one batch each)
        self.batch_size = 12
        self.print_every = 40      # report progress every this many iterations
        self.hidden_size = 256
        self.num_layers = 2
        self.lr = 0.02
        self.len = len(data.train_dl)
        self.rnn = None            # created by train()
        self._batch_iter = None    # persistent iterator over data.train_dl

    # To get random batch
    def random_batch(self):
        """Return the next `(input, target)` batch from `data.train_dl` as long tensors on `device`.

        One iterator is kept alive across calls, so successive calls walk through
        the loader instead of rebuilding it every time; it restarts when exhausted.
        """
        if self._batch_iter is None:
            self._batch_iter = iter(data.train_dl)
        try:
            text_input, text_target = next(self._batch_iter)
        except StopIteration:
            self._batch_iter = iter(data.train_dl)
            text_input, text_target = next(self._batch_iter)

        return text_input.to(device).long(), text_target.to(device).long()

    # Generate Function
    def generate(self, initial_str=' today the news is ', predict_len=100, temperature=0.85):
        """Sample `predict_len` tokens that continue `initial_str`.

        Args:
            initial_str: prompt text; it is lower-cased and split on whitespace,
                and every piece is looked up in `v.stoi`.
            predict_len: number of tokens to sample.
            temperature: positive divisor applied to the scores before sampling;
                lower values make the sampling greedier.

        Returns:
            `initial_str` followed by the sampled tokens joined with spaces.

        Raises:
            RuntimeError: if called before `train()` has created the model.
            ValueError: if `temperature` is not positive.
        """
        if self.rnn is None:
            raise RuntimeError("generate() requires a trained model; call train() first")
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        # NOTE(review): whitespace splitting only approximates the tokenizer used to
        # build the DataBunch (no xxmaj/xxup markers) — confirm that is acceptable.
        prompt_tokens = initial_str.lower().split() or ['xxbos']
        prompt_ids = [v.stoi[token] for token in prompt_tokens]

        generated = []
        with torch.no_grad():
            # Sampling handles a single sequence, so the state has batch size 1.
            hidden, cell = self.rnn.init_hidden(batch_size=1)

            # Prime the recurrent state with the prompt, one token per step.
            output = None
            for token_id in prompt_ids:
                step_input = torch.tensor([token_id], device=device).long()
                output, (hidden, cell) = self.rnn(step_input, hidden, cell)

            for _ in range(predict_len):
                # Softmax keeps the temperature-scaled scores numerically stable.
                probs = torch.softmax(output.view(-1) / temperature, dim=0)
                next_id = int(torch.multinomial(probs, 1).item())
                generated.append(v.itos[next_id])

                # Feed the sampled token back in to obtain the next distribution.
                step_input = torch.tensor([next_id], device=device).long()
                output, (hidden, cell) = self.rnn(step_input, hidden, cell)

        separator = '' if (not initial_str or initial_str[-1].isspace()) else ' '
        return initial_str + separator + ' '.join(generated)

    def _train_step(self, inp, target, optimizer, criterion):
        """Run one optimisation step on one batch; return `(mean loss per step, accuracy %)`."""
        steps = min(self.chunk_len, inp.shape[1])
        if steps == 0:
            raise ValueError("received a batch with no time steps")

        # Size the state from the batch itself rather than assuming self.batch_size.
        hidden, cell = self.rnn.init_hidden(batch_size=inp.shape[0])
        self.rnn.zero_grad()

        loss = 0
        correct = 0
        total = 0
        for i in range(steps):
            output, (hidden, cell) = self.rnn(inp[:, i], hidden, cell)
            step_target = target[:, i]
            loss += criterion(output, step_target)
            correct += (output.argmax(dim=1) == step_target).sum().item()
            total += step_target.shape[0]

        loss.backward()
        optimizer.step()
        return loss.item() / steps, 100 * correct / total

    # Train Function
    def train(self):
        """Create a fresh `RNN` and train it for `self.num_epochs` iterations.

        A new batch is drawn on every iteration; training on one fixed batch would
        only memorise that batch. Progress and a generated sample are printed every
        `self.print_every` iterations.
        """
        self.rnn = RNN(n_text, self.hidden_size, self.num_layers, n_text).to(device)

        optimizer = torch.optim.Adam(self.rnn.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()

        print( " Starting Training : ")
        for epoch in range(1, self.num_epochs + 1):
            inp, target = self.random_batch()
            loss, accuracy = self._train_step(inp, target, optimizer, criterion)

            if epoch % self.print_every == 0:
                print(f"Iteration: {epoch}")
                print(f"Loss: {loss}")
                print(f"Accuracy: {accuracy}")
                print(self.generate())


In [None]:
# Build the trainer and run the full training loop (prints progress as it goes).
getNews = Generator()
getNews.train()

 Starting Training : 
Iteration: 40
Loss: 5.009692818777902
Iteration: 80
Loss: 4.122496686662946
Iteration: 120
Loss: 3.261800275530134
Iteration: 160
Loss: 2.648405020577567
Iteration: 200
Loss: 1.971807861328125
