In [91]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from typing import List, Dict

In [148]:
class timeSeries:
    """Per-country time series of case counts with a running cumulative total.

    Expects a DataFrame with columns ``dateRep`` (strings formatted ``dd/mm/yyyy``),
    ``day``, ``month``, ``year``, ``cases`` and ``countriesAndTerritories``.

    Attributes:
        rawData: copy of the input frame with ``dateRep`` parsed to datetime.
        processedData: result of ``extract_countries_and_time`` for the
            constructor's ``countries``/``start``/``end`` arguments.
    """

    def __init__(self, df: pd.DataFrame, countries: List[str], start: str, end: str):
        """
        :param pd.DataFrame df: raw input frame; it is NOT modified
        :param list<str> countries: values of ``countriesAndTerritories`` to keep
        :param str start: inclusive lower date bound ``yyyy-mm-dd``, or "start" for none
        :param str end: inclusive upper date bound ``yyyy-mm-dd``, or "end" for none
        """
        # Parse dates on a copy: mutating the caller's frame makes notebook
        # cells depend on whether this constructor has already run.
        raw = df.copy()
        raw["dateRep"] = pd.to_datetime(raw["dateRep"], format="%d/%m/%Y")
        self.rawData = raw
        self.processedData = self.extract_countries_and_time(countries, start, end)

    def extract_countries_and_time(self, countries: List[str], start: str, end: str) -> pd.DataFrame:
        """Select countries and an inclusive date window from ``rawData``.

        :param list<str> countries: values of ``countriesAndTerritories`` to keep
        :param str start: inclusive lower bound ``yyyy-mm-dd``, or "start" for no lower bound
        :param str end: inclusive upper bound ``yyyy-mm-dd``, or "end" for no upper bound
        :return: rows sorted by (year, month, day) with columns dateRep, day, month,
            year, cases, countriesAndTerritories and cumCases, re-indexed from 0
        """
        columns = ["dateRep", "day", "month", "year", "cases", "countriesAndTerritories"]
        selected = self.rawData.loc[
            self.rawData["countriesAndTerritories"].isin(countries), columns
        ]
        df = selected.sort_values(["year", "month", "day"]).reset_index(drop=True)

        # The running total is computed over the full history BEFORE the date
        # window is applied, so cumCases at `start` includes earlier cases.
        df["cumCases"] = df.groupby("countriesAndTerritories")["cases"].cumsum()

        # Each bound is applied independently; the sentinels disable that bound.
        in_window = pd.Series(True, index=df.index)
        if start != "start":
            in_window &= df["dateRep"] >= start
        if end != "end":
            in_window &= df["dateRep"] <= end
        return df[in_window].reset_index(drop=True)

    def plot(self) -> Dict[str, pd.DataFrame]:
        """Plot cumulative cases over time, one line per country, and show the figure.

        :return: dict mapping each ``countriesAndTerritories`` value to its rows
            of ``processedData``
        """
        df = self.processedData
        fig, ax = plt.subplots(figsize=(16, 8))
        ax.margins(0.05)

        output_dict = {}
        for name, group in df.groupby("countriesAndTerritories"):
            ax.plot(group["dateRep"], group["cumCases"], label=name)
            output_dict[name] = group

        ax.set(title="Cumulative cases by country", xlabel="Date", ylabel="Cumulative cases")
        # Avoid matplotlib's "no artists with labels" warning when nothing matched.
        if output_dict:
            ax.legend()
        plt.show()

        return output_dict
    

In [152]:
df = pd.read_csv("data/europe_timeseries.csv")
# Date bounds are inclusive and compared as timestamps: a partial date such as
# "2020-04" would silently mean 2020-04-01, so the full yyyy-mm-dd is spelled out.
ts = timeSeries(df, ["Germany", "Italy"], "2020-02-20", "2020-04-01")
assert not ts.processedData.empty, "no rows matched the selected countries/date window"
ts.processedData.head()