# Naive Bayes (the easy way)

Prepare data set (can be skipped, it is already there)

In [4]:
import os
import shutil

source_root = '/Users/apismenskiy/git'
directory_path = os.path.join(source_root, 'source')
output_path = 'source'
java_path = os.path.join(output_path, 'java')
cpp_path = os.path.join(output_path, 'cpp')
scala_path = os.path.join(output_path, 'scala')
js_path = os.path.join(output_path, 'javascript')
py_path = os.path.join(output_path, 'python')
text_path = os.path.join(output_path, 'plaintext')

def find_files_with_extension(directory, extension):
    file_list = []

    for root, dirs, files in os.walk(directory):
        for file in files:
            if file.endswith(extension):
                file_list.append(os.path.join(root, file))

    return file_list

def create_and_move_files(directory, extension, output_folder_name):
    found_files = find_files_with_extension(directory, extension)

    if not found_files:
        print(f"No files with '{extension}' extension found in the specified directory.")
        return

    if not os.path.exists(output_folder_name):
        os.makedirs(output_folder_name)
        print(f"Created folder '{output_folder_name}' to store files.")

    for file_path in found_files:
        new_file_path = os.path.join(output_folder_name, os.path.basename(file_path))
        shutil.move(file_path, new_file_path)
        print(f"Moved '{file_path}' to '{new_file_path}'")



create_and_move_files(os.path.join(source_root, 'tika'), 'java', java_path)
create_and_move_files(os.path.join(source_root, 'tesseract'), 'cpp', cpp_path)
create_and_move_files(os.path.join(source_root, 'playframework'), 'scala', scala_path)
create_and_move_files(os.path.join(source_root, 'jquery'), 'js', js_path)
create_and_move_files(os.path.join(source_root, 'scikit-learn'), 'py', py_path)


No files with 'java' extension found in the specified directory.
No files with 'cpp' extension found in the specified directory.
No files with 'scala' extension found in the specified directory.
No files with 'js' extension found in the specified directory.
No files with 'py' extension found in the specified directory.


We'll cheat by using sklearn.naive_bayes to train a source code classifier! Most of the code is just loading our training data into a pandas DataFrame that we can play with:

In [5]:
import os
import io
import numpy
import pandas as pd
from pandas import DataFrame
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

def readFiles(path):
    for root, dirnames, filenames in os.walk(path):
        for filename in filenames:
            path = os.path.join(root, filename)

            inBody = False
            lines = []
            f = io.open(path, 'r', encoding='latin1')
            for line in f:
                if inBody:
                    lines.append(line)
                elif line == '\n':
                    inBody = True
            f.close()
            message = '\n'.join(lines)
            yield path, message


def dataFrameFromDirectory(path, classification):
    rows = []
    index = []
    for filename, message in readFiles(path):
        rows.append({'message': message, 'class': classification})
        index.append(filename)

    return DataFrame(rows, index=index)

data = DataFrame({'message': [], 'class': []})

data = pd.concat([data, dataFrameFromDirectory(cpp_path, "cpp")])
data = pd.concat([data, dataFrameFromDirectory(java_path, "java")])
data = pd.concat([data, dataFrameFromDirectory(js_path, "javascript")])
data = pd.concat([data, dataFrameFromDirectory(py_path, "python")])
data = pd.concat([data, dataFrameFromDirectory(scala_path, "scala")])
data = pd.concat([data, dataFrameFromDirectory(text_path, "text")])



Let's have a look at that DataFrame:

In [13]:
data.head()

Unnamed: 0,message,class
source/cpp/permdawg.cpp,"#include ""dawg.h""\n\n#include ""params.h""\n\n#i...",cpp
source/cpp/edgblob.cpp,// Include automatically generated configurati...,cpp
source/cpp/pithsync.cpp,"#include ""pithsync.h""\n\n\n\n#include ""makerow...",cpp
source/cpp/picofeat.cpp,"#include ""picofeat.h""\n\n\n\n#include ""classif...",cpp
source/cpp/tessvars.cpp,"#include <cstdio>\n\n\n\n#include ""tessvars.h""...",cpp


Now we will use a CountVectorizer to split up each message into its list of words, and throw that into a MultinomialNB classifier. Call fit() and we've got a trained spam filter ready to go! It's just that easy.

In [14]:
vectorizer = CountVectorizer()
counts = vectorizer.fit_transform(data['message'].values)

classifier = MultinomialNB()
targets = data['class'].values
classifier.fit(counts, targets)

Let's try it out:

In [16]:
examples = [
    # scala
    '''package org.threeten.bp

import java.util.NavigableMap
import org.threeten.bp.zone.ZoneMap

object Platform {
type NPE = NullPointerException
type DFE = IndexOutOfBoundsException
type CCE = ClassCastException

/**
* Returns `true` if and only if the code is executing on a JVM. Note: Returns `false` when
* executing on any JS VM.
*/
final val executingInJVM = true

def setupLocales(): Unit = {}

def zoneMap(m: scala.collection.immutable.TreeMap[Int, String]): NavigableMap[Int, String] =
ZoneMap(m)
}''',
    # java
    ''' public static void run() {

ProfileCredentialsProvider awsCredentialsProvider = ProfileCredentialsProvider.create();

CLIENT = TextractClient.builder()
        .region(region)
        .credentialsProvider(awsCredentialsProvider)
        .build();

String absolutePath = getAbsolutePath();
CATEGORIES.forEach(category -> {

    String path = absolutePath + DLMTR + DATA_ROOT + DLMTR + category;
    Set<Path> ocrFiles = getOcrFiles(path);
    System.out.println(path + ": Found image files: " + ocrFiles);''',
    # python
    '''class Polygon:
    def sides_no(self):
        pass

class Triangle(Polygon):
    def area(self):
        pass

obj_polygon = Polygon()
obj_triangle = Triangle()

print(type(obj_triangle) == Triangle)   	# true
print(type(obj_triangle) == Polygon)    	# false

print(isinstance(obj_polygon, Polygon)) 	# true
print(isinstance(obj_triangle, Polygon))	# true''',

    # cpp
    '''#include <iostream>
    #include <iostream>
using namespace std;

int main() {
int n;

cout << "Enter an integer: ";
cin >> n;

if ( n % 2 == 0)
cout << n << " is even.";
else
cout << n << " is odd.";

return 0;
}''',

# javascript
'''
console.log("Hello World");

var canvas = document.getElementById("canvas");
var c = canvas.getContext("2d");
var tx = window.innerWidth;
var ty = window.innerHeight;
canvas.width = tx;
canvas.height = ty;
//c.lineWidth= 5;
//c.globalAlpha = 0.5;

var mousex = 0;
var mousey = 0;

addEventListener("mousemove", function() {
  mousex = event.clientX;
  mousey = event.clientY;
});


var grav = 0.99;
c.strokeWidth=5;
function randomColor() {
  return (
    "rgba(" +
    Math.round(Math.random() * 250) +
    "," +
    Math.round(Math.random() * 250) +
    "," +
    Math.round(Math.random() * 250) +
    "," +
    Math.ceil(Math.random() * 10) / 10 +
    ")"
  );
}

function Ball() {
  this.color = randomColor();
  this.radius = Math.random() * 20 + 14;
  this.startradius = this.radius;
  this.x = Math.random() * (tx - this.radius * 2) + this.radius;
  this.y = Math.random() * (ty - this.radius);
  this.dy = Math.random() * 2;
  this.dx = Math.round((Math.random() - 0.5) * 10);
  this.vel = Math.random() /5;
  this.update = function() {
    c.beginPath();
    c.arc(this.x, this.y, this.radius, 0, 2 * Math.PI);
    c.fillStyle = this.color;
    c.fill();
    //c.stroke();
  };
}

var bal = [];
for (var i=0; i<50; i++){
    bal.push(new Ball());
}

function animate() {
  if (tx != window.innerWidth || ty != window.innerHeight) {
    tx = window.innerWidth;
    ty = window.innerHeight;
    canvas.width = tx;
    canvas.height = ty;
  }
  requestAnimationFrame(animate);
  c.clearRect(0, 0, tx, ty);
  for (var i = 0; i < bal.length; i++) {
    bal[i].update();
    bal[i].y += bal[i].dy;
    bal[i].x += bal[i].dx;
    if (bal[i].y + bal[i].radius >= ty) {
      bal[i].dy = -bal[i].dy * grav;
    } else {
      bal[i].dy += bal[i].vel;
    }
    if(bal[i].x + bal[i].radius > tx || bal[i].x - bal[i].radius < 0){
        bal[i].dx = -bal[i].dx;
    }
    if(mousex > bal[i].x - 20 &&
      mousex < bal[i].x + 20 &&
      mousey > bal[i].y -50 &&
      mousey < bal[i].y +50 &&
      bal[i].radius < 70){
        //bal[i].x += +1;
        bal[i].radius +=5;
      } else {
        if(bal[i].radius > bal[i].startradius){
          bal[i].radius += -5;
        }
      }

    //forloop end
    }
//animation end
}

animate();

setInterval(function() {
  bal.push(new Ball());
  bal.splice(0, 1);
}, 400);

''',
# plain text
            '''World War II or the Second World War, often abbreviated as WWII or WW2, was a global conflict lasted from 1939 to 1945. The vast majority of the world's countries, including all of the great powers, fought as part of two opposing military alliances: the Allies and the Axis. Many participants threw their economic, industrial, and scientific capabilities behind this total war, blurring the distinction between civilian and military resources. Aircraft played a major role, enabling the strategic bombing of population centres and the delivery of the only two nuclear weapons ever used in war. World War II was by far the deadliest conflict in history, resulting in an estimated 70 to 85 million fatalities, mostly among civilians. Tens of millions died due to genocides (including the Holocaust), starvation, massacres, and disease. In the wake of the Axis defeat, Germany and Japan were occupied, and war crimes tribunals were conducted against German and Japanese leaders.'''
            ]
example_counts = vectorizer.transform(examples)
predictions = classifier.predict(example_counts)
predictions

array(['scala', 'java', 'python', 'cpp', 'javascript', 'text'],
      dtype='<U10')

## Activity

Our data set is small, so our spam classifier isn't actually very good. Try running some different test emails through it and see if you get the results you expect.

TODO If you really want to challenge yourself, try applying train/test to this spam classifier - see how well it can predict some subset of the ham and spam emails.